diff --git a/.kilo/command/benchmarks-affected.md b/.kilo/command/benchmarks-affected.md deleted file mode 100644 index 158cb9e..0000000 --- a/.kilo/command/benchmarks-affected.md +++ /dev/null @@ -1,34 +0,0 @@ ---- -description: Scan current branch and report impacted benchmark targets/functions. ---- - -# Benchmarks Affected - -Identify which benchmark binaries and benchmark functions are affected by changes on the current branch. - -Use the `benchmarks-affected` skill as the single source of truth for workflow details and guardrails. -Do not duplicate or override the skill instructions in this command. - -## Inputs - -- Optional `--baseline ` (default: `main`) -- Optional `--compile-commands ` -- Optional `--no-include-working-tree` -- Optional `--format ` (default: `text`) - -## Workflow - -1. Execute the `benchmarks-affected` skill workflow. -2. Pass through command inputs to the analyzer invocation defined by the skill. -3. Report results with these sections: - - Changed files - - Affected benchmark targets - - Affected benchmark functions - - Suggested `--benchmark_filter` regex - - Any warnings/failures - -## Output rules - -1. If `affected_benchmarks` is non-empty, prioritize those names. -2. If `affected_benchmarks` is empty but benchmark targets are affected, mark result as partial and include target-level impact. -3. Do not run full benchmark suites in this command; this command is for impact discovery only. diff --git a/.kilo/command/perf-review.md b/.kilo/command/perf-review.md deleted file mode 100644 index 8ef1865..0000000 --- a/.kilo/command/perf-review.md +++ /dev/null @@ -1,149 +0,0 @@ ---- -description: Benchmark-driven PR performance review versus target branch ---- - -# Perf Review Workflow - -You are performing a performance review for the current PR branch. - -Non-negotiable requirements: -1. Benchmark timing plus profiling data is the highest-priority judgment tool. -2. Compare source branch versus target branch and report relevant benchmark metric changes. -3. Provide analysis and a final verdict: does the PR improve performance or not. - -## Inputs - -- Optional argument `--target `: target branch override. -- Optional argument `--filter `: benchmark filter regex. -- Optional argument `--no-counters`: disable hardware-counter collection. - -If arguments are omitted: -- Default target branch to PR base branch from `gh pr view --json baseRefName` when available. -- Fall back target branch to `main`. - -Filter handling: -- If `--filter` is provided, pass it through. -- Else use the filter produced by `benchmarks-affected` through `benchmarks-compare-revisions`. -- If no filter can be derived, run conservative full-binary compare for impacted binaries. - -## Step 1 - Resolve branches and hashes - -1. Resolve contender from current checkout (`HEAD`) and compute short hash. -2. Resolve baseline branch using precedence: `--target` -> PR base from `gh pr view --json baseRefName` -> `main`. -3. Resolve baseline short hash. -4. Print branch/hash mapping before benchmark execution. - -## Step 2 - Run timing and hardware-counter comparison via skill (single source of truth) - -Use `benchmarks-compare-revisions` as the single source of truth for revision builds, benchmark scope, compare.py flow, retry policy, and guardrails. - -Pass-through inputs: -- Baseline ref/hash from Step 1. -- Contender ref/hash from Step 1. -- Optional `--filter` override. -- Counter mode: default on (`COLLECT_COUNTERS=1`) on Linux, disabled when `--no-counters` is provided. - -Consume outputs from `benchmarks-compare-revisions`: -- Baseline and contender benchmark JSON artifacts. -- compare.py output per binary. -- Effective filter used. -- Scope metadata from `benchmarks-affected` (`affected_benchmark_targets`, `affected_benchmarks`) when available. -- `counters_available` status and, when unavailable, explicit reason. -- Baseline and contender counter JSON artifacts (when available). -- Derived counter metrics per benchmark (IPC, cache miss rate, branch mispredict rate). -- Counter anomaly list and ready-to-embed counter summary table. - -Execution guardrails: -- Run benchmarks sequentially. -- No background jobs (`nohup`, `&`). -- Use Release timing builds only. -- If timing comparison fails, return blocked verdict with exact failure points. - -## Step 3 - Consume delegated hardware-counter outputs - -Hardware-counter collection is delegated to `benchmarks-compare-revisions`. - -Pass-through inputs: -- `COLLECT_COUNTERS=1` by default on Linux (unless `--no-counters` is provided). -- Same baseline/contender refs and effective filter used in Step 2. - -Consume outputs: -- Counter preflight result. -- Counter JSON artifacts for both revisions. -- Derived metrics (IPC, cache miss rate, branch mispredict rate). -- Anomaly list and counter summary table for report embedding. - -If counters are unavailable (`counters_available=false`), continue with timing-only review and explicitly mark profiling as unavailable in the report. - -## Step 4 - Analyze timing and counter data - -Timing classification per benchmark entry: -- Improvement: time delta < -5% -- Regression: time delta > +5% -- Neutral: between -5% and +5% - -Aggregate per binary: -- Number of improvements/regressions/neutral -- Net average percentage change -- Largest regression and largest improvement - -Counter correlation: -- Use skill-provided hardware counter summary and anomaly list to explain major timing changes. -- Do not recompute derived counter metrics in this command. - -Judgment priority: -- Base verdict primarily on benchmark timing comparison. -- Use counter data as explanatory evidence and confidence signal. - -Noise-control expectations: -- Include at least one control benchmark family expected to be unaffected by the code change. -- Treat isolated swings without pattern as noise unless reproduced across related sizes/fill ratios. - -## Step 5 - Produce final markdown report - -Return a structured markdown report with this shape: - -```markdown -## Performance Review: vs - -### Configuration -- Baseline: () -- Contender: () -- Platform: -- Benchmarks run: -- Filter: -- Hardware counters: available / unavailable - -### Timing Summary -| Binary | Improvements | Regressions | Neutral | Net Change | -|---|---:|---:|---:|---:| -| ... | N | N | N | +/-X% | - -### Detailed Timing Results - - -### Hardware Counter Profile (if available) -| Benchmark | IPC (base->new) | Cache Miss Rate (base->new) | Branch Mispredict (base->new) | -|---|---:|---:|---:| -| ... | X.XX -> Y.YY | A.A% -> B.B% | C.C% -> D.D% | - -### Key Findings -- -- - -### Verdict -**[IMPROVES PERFORMANCE | REGRESSES PERFORMANCE | NO SIGNIFICANT CHANGE]** - -<1-2 sentence justification grounded in benchmark metrics, with profiling context if available> -``` - -Verdict rules: -- `IMPROVES PERFORMANCE`: improvements outnumber regressions, no severe regression (>10%), and net average change is favorable. -- `REGRESSES PERFORMANCE`: any severe regression (>10%) or regressions dominate with net unfavorable average. -- `NO SIGNIFICANT CHANGE`: mostly neutral changes or mixed results that approximately cancel out. - -## Failure Handling - -- If required builds fail or timing comparison cannot run, output a blocked review with exact failure points and no misleading verdict. -- If only profiling fails (`counters_available=false` from delegated skill output), continue with timing-based verdict and explicitly list profiling limitation. -- If JSON output is invalid/truncated, discard it and rerun that benchmark command once with tighter filter and explicit output redirection. diff --git a/.kilo/command/ping.md b/.kilo/command/ping.md deleted file mode 100644 index b5edaf7..0000000 --- a/.kilo/command/ping.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -description: Test command that replies with pong ---- - -Respond with exactly `pong`. -Do not add any other words. -Do not add quotes or punctuation. diff --git a/.kilo/skills/benchmarks-affected/SKILL.md b/.kilo/skills/benchmarks-affected/SKILL.md deleted file mode 100644 index e886c92..0000000 --- a/.kilo/skills/benchmarks-affected/SKILL.md +++ /dev/null @@ -1,77 +0,0 @@ ---- -name: benchmarks-affected -description: Analyze current branch versus a baseline and extract affected benchmark targets and benchmark functions using compile_commands and clang AST. ---- - -# Benchmarks Affected Skill - -Use this skill to identify exactly which benchmark binaries and benchmark functions are affected by code changes on the current branch. - -It implements a two-stage workflow: - -1. `compile_commands.json` analysis to determine affected compile targets. -2. Clang AST analysis to determine affected benchmark functions. - -## Goal - -Given `HEAD` and a baseline branch (default `main`), produce: - -- Changed files. -- Affected targets (with emphasis on benchmark targets). -- Exact benchmark functions impacted by the changes. -- A ready-to-use Google Benchmark filter regex. - -## Prerequisites - -1. Build tree with benchmarks enabled and compile database exported: - -```bash -BUILD_SUFFIX=local -cmake -B build/benchmarks-all_${BUILD_SUFFIX} \ - -DCMAKE_BUILD_TYPE=Release \ - -DPIXIE_BENCHMARKS=ON \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -cmake --build build/benchmarks-all_${BUILD_SUFFIX} --config Release -j -``` - -2. `clang++` must be available on `PATH` (used for AST dump). - -## Run - -```bash -python3 .kilo/skills/benchmarks-affected/analyze_benchmarks_affected.py \ - --baseline main \ - --compile-commands build/benchmarks-all_local/compile_commands.json \ - --format json -``` - -If `--compile-commands` is omitted, the script auto-selects the most recently modified `build/**/compile_commands.json`. -Working tree changes are included by default. Use `--no-include-working-tree` to restrict analysis to `...HEAD` only. - -## Output - -The analyzer reports: - -- `affected_targets`: impacted CMake targets inferred from compile dependency analysis. -- `affected_benchmark_targets`: subset of benchmark binaries impacted. -- `affected_benchmarks`: precise benchmark function names from AST-level call analysis. -- `suggested_filter_regex`: regex to pass as `--benchmark_filter`. - -## How to Use Findings - -1. Build only impacted benchmark binaries where feasible. -2. Run benchmark binaries with the suggested filter: - -```bash -FILTER='^(BM_RankNonInterleaved|BM_SelectNonInterleaved)(/|$)' -build/benchmarks-all_local/benchmarks --benchmark_filter="${FILTER}" -``` - -3. If impact mapping is broad/uncertain, run full binary for selected benchmark target(s). - -## Guardrails - -1. Keep baseline comparison at merge-base style diff: `...HEAD`. -2. Use Release binaries for timing runs. -3. If AST parse fails for a TU, still trust compile target impact and mark benchmark-function scope as partial. -4. If benchmark infra (`CMakeLists.txt`, benchmark source layout) changed, fall back to conservative benchmark selection. diff --git a/.kilo/skills/benchmarks-affected/analyze_benchmarks_affected.py b/.kilo/skills/benchmarks-affected/analyze_benchmarks_affected.py deleted file mode 100644 index e858d45..0000000 --- a/.kilo/skills/benchmarks-affected/analyze_benchmarks_affected.py +++ /dev/null @@ -1,1138 +0,0 @@ -#!/usr/bin/env python3 - -from __future__ import annotations - -import argparse -import concurrent.futures -import json -import os -import re -import shlex -import shutil -import subprocess -import sys -from collections import defaultdict -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - - -def is_project_source(source: Path, repo_root: Path) -> bool: - """Exclude third-party deps and generated build files.""" - try: - rel = source.relative_to(repo_root) - except ValueError: - return False - rel_text = rel.as_posix() - if rel_text.startswith("build/") or "_deps/" in rel_text: - return False - return True - - -KNOWN_BENCHMARK_TARGETS = { - "benchmarks", - "bench_rmm", - "bench_rmm_sdsl", - "louds_tree_benchmarks", - "alignment_comparison", -} - -HEADER_EXTENSIONS = { - ".h", - ".hh", - ".hpp", - ".hxx", - ".inc", - ".ipp", - ".tcc", -} - -BUILD_INFRA_FILES = { - "CMakeLists.txt", - "CMakePresets.json", -} - -DIFF_HUNK_RE = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@") - -CPP_FUNCTION_START_RE = re.compile( - r"^\s*" - r"(?:template\s*<[^>]*>\s*)?" - r"(?:(?:inline|constexpr|consteval|constinit|static|friend|virtual|explicit)\s+)*" - r"[A-Za-z_~][\w:<>,\s\*&\[\]]*\s+" - r"([~A-Za-z_][A-Za-z0-9_]*)\s*" - r"\([^;{}]*\)\s*" - r"(?:const\s*)?" - r"(?:noexcept\s*)?" - r"(?:->\s*[^\{]+)?\{" -) - - -@dataclass -class CompileCommandEntry: - directory: Path - source: Path - arguments: list[str] - output: Path | None - target: str | None - dependencies: set[Path] = field(default_factory=set) - dep_error: str | None = None - - -@dataclass -class AstImpactResult: - benchmark_names: set[str] = field(default_factory=set) - affected_names: set[str] = field(default_factory=set) - ast_error: str | None = None - - -def run_command( - args: list[str], - cwd: Path, - check: bool = True, - timeout: float | None = 60.0, -) -> subprocess.CompletedProcess[str]: - return subprocess.run( - args, - cwd=str(cwd), - text=True, - capture_output=True, - check=check, - timeout=timeout, - ) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description=( - "Analyze benchmark impact between baseline and HEAD via " - "compile_commands dependency mapping and clang AST analysis." - ) - ) - parser.add_argument( - "--baseline", - default="main", - help="Baseline ref used as ...HEAD (default: main).", - ) - parser.add_argument( - "--head", - default="HEAD", - help="Contender ref (default: HEAD).", - ) - parser.add_argument( - "--compile-commands", - default=None, - help=( - "Path to compile_commands.json. If omitted, auto-discovers most " - "recent build/**/compile_commands.json." - ), - ) - parser.add_argument( - "--clangxx", - default=None, - help="clang++ executable for AST dump (auto-detected if omitted).", - ) - parser.add_argument( - "--format", - choices=["text", "json"], - default="text", - help="Output format (default: text).", - ) - parser.add_argument( - "--include-working-tree", - dest="include_working_tree", - action="store_true", - default=True, - help=( - "Include local unstaged/staged changes in changed-files set, " - "in addition to ... (default: enabled)." - ), - ) - parser.add_argument( - "--no-include-working-tree", - dest="include_working_tree", - action="store_false", - help="Disable working-tree inclusion and only analyze ....", - ) - return parser.parse_args() - - -def git_repo_root() -> Path: - proc = run_command(["git", "rev-parse", "--show-toplevel"], cwd=Path.cwd()) - return Path(proc.stdout.strip()).resolve() - - -def resolve_compile_commands(repo_root: Path, explicit_path: str | None) -> Path: - if explicit_path: - compile_path = Path(explicit_path) - if not compile_path.is_absolute(): - compile_path = (repo_root / compile_path).resolve() - if not compile_path.exists(): - raise FileNotFoundError(f"compile_commands.json not found: {compile_path}") - return compile_path - - candidates = sorted( - repo_root.glob("build/**/compile_commands.json"), - key=lambda path: path.stat().st_mtime, - reverse=True, - ) - if not candidates: - raise FileNotFoundError( - "No compile_commands.json found under build/**. " - "Configure with -DCMAKE_EXPORT_COMPILE_COMMANDS=ON first." - ) - return candidates[0].resolve() - - -def load_compile_commands( - compile_commands_path: Path, - repo_root: Path, -) -> list[CompileCommandEntry]: - entries: list[CompileCommandEntry] = [] - data = json.loads(compile_commands_path.read_text(encoding="utf-8")) - for raw_entry in data: - directory = Path(raw_entry["directory"]).resolve() - - raw_source = Path(raw_entry["file"]) - if raw_source.is_absolute(): - source = raw_source.resolve() - else: - source = (directory / raw_source).resolve() - - if not is_project_source(source, repo_root): - continue - - if "arguments" in raw_entry: - arguments = [str(arg) for arg in raw_entry["arguments"]] - else: - arguments = shlex.split(raw_entry["command"]) - - output = infer_output_path(arguments, directory) - target = infer_cmake_target_from_output(output) - - entries.append( - CompileCommandEntry( - directory=directory, - source=source, - arguments=arguments, - output=output, - target=target, - ) - ) - return entries - - -def infer_output_path(arguments: list[str], directory: Path) -> Path | None: - output_token: str | None = None - for idx, arg in enumerate(arguments): - if arg == "-o" and idx + 1 < len(arguments): - output_token = arguments[idx + 1] - elif arg.startswith("-o") and len(arg) > 2: - output_token = arg[2:] - elif arg.startswith("/Fo") and len(arg) > 3: - output_token = arg[3:] - - if output_token is None: - return None - - out_path = Path(output_token) - if out_path.is_absolute(): - return out_path.resolve() - return (directory / out_path).resolve() - - -def infer_cmake_target_from_output(output: Path | None) -> str | None: - if output is None: - return None - parts = output.parts - for index, part in enumerate(parts): - if part == "CMakeFiles" and index + 1 < len(parts): - target_part = parts[index + 1] - if target_part.endswith(".dir"): - return target_part[: -len(".dir")] - return target_part - return None - - -def git_changed_files(repo_root: Path, baseline: str, head: str) -> set[Path]: - diff_range = f"{baseline}...{head}" - proc = run_command(["git", "diff", "--name-only", diff_range], cwd=repo_root) - changed_files: set[Path] = set() - for line in proc.stdout.splitlines(): - line = line.strip() - if not line: - continue - changed_files.add((repo_root / line).resolve()) - return changed_files - - -def git_working_tree_changed_files(repo_root: Path) -> set[Path]: - changed_files: set[Path] = set() - commands = [ - ["git", "diff", "--name-only"], - ["git", "diff", "--name-only", "--cached"], - ] - for cmd in commands: - proc = run_command(cmd, cwd=repo_root) - for line in proc.stdout.splitlines(): - line = line.strip() - if not line: - continue - changed_files.add((repo_root / line).resolve()) - return changed_files - - -def parse_changed_lines_from_diff_text( - diff_text: str, - repo_root: Path, -) -> dict[Path, set[int]]: - changed_lines: dict[Path, set[int]] = defaultdict(set) - - current_file: Path | None = None - in_hunk = False - new_line = 0 - - for raw_line in diff_text.splitlines(): - if raw_line.startswith("+++ "): - file_token = raw_line[4:].strip() - if file_token == "/dev/null": - current_file = None - in_hunk = False - continue - if file_token.startswith("b/"): - file_token = file_token[2:] - current_file = (repo_root / file_token).resolve() - in_hunk = False - continue - - hunk_match = DIFF_HUNK_RE.match(raw_line) - if hunk_match: - in_hunk = current_file is not None - new_line = int(hunk_match.group(1)) - continue - - if not in_hunk or current_file is None: - continue - - if raw_line.startswith("+") and not raw_line.startswith("+++"): - changed_lines[current_file].add(new_line) - new_line += 1 - continue - - if raw_line.startswith("-") and not raw_line.startswith("---"): - continue - - if raw_line.startswith(" "): - new_line += 1 - continue - - return changed_lines - - -def git_changed_line_map( - repo_root: Path, - baseline: str, - head: str, - include_working_tree: bool, -) -> dict[Path, set[int]]: - changed_lines: dict[Path, set[int]] = defaultdict(set) - - proc = run_command( - ["git", "diff", "--unified=0", f"{baseline}...{head}"], - cwd=repo_root, - ) - baseline_map = parse_changed_lines_from_diff_text(proc.stdout, repo_root) - for path, lines in baseline_map.items(): - changed_lines[path].update(lines) - - if include_working_tree: - for cmd in ( - ["git", "diff", "--unified=0"], - ["git", "diff", "--cached", "--unified=0"], - ): - wt_proc = run_command(cmd, cwd=repo_root) - wt_map = parse_changed_lines_from_diff_text(wt_proc.stdout, repo_root) - for path, lines in wt_map.items(): - changed_lines[path].update(lines) - - return changed_lines - - -def extract_changed_symbol_names_from_file( - file_path: Path, - changed_lines: set[int], -) -> set[str]: - if not changed_lines or not file_path.exists(): - return set() - - lines = file_path.read_text(encoding="utf-8", errors="replace").splitlines() - symbols: set[str] = set() - - line_index = 1 - max_line = len(lines) - while line_index <= max_line: - line = lines[line_index - 1] - match = CPP_FUNCTION_START_RE.match(line) - if not match: - line_index += 1 - continue - - symbol_name = match.group(1) - start_line = line_index - brace_depth = line.count("{") - line.count("}") - end_line = start_line - - while brace_depth > 0 and end_line < max_line: - end_line += 1 - body_line = lines[end_line - 1] - brace_depth += body_line.count("{") - body_line.count("}") - - if any(start_line <= line_no <= end_line for line_no in changed_lines): - symbols.add(symbol_name) - - line_index = end_line + 1 - - return symbols - - -def collect_changed_symbol_names( - changed_line_map: dict[Path, set[int]], -) -> set[str]: - symbol_names: set[str] = set() - for file_path, changed_lines in changed_line_map.items(): - symbol_names.update( - extract_changed_symbol_names_from_file(file_path, changed_lines) - ) - return symbol_names - - -def clean_command_for_dependency_scan(arguments: list[str]) -> list[str]: - cleaned: list[str] = [] - skip_next = False - flags_with_value = { - "-o", - "-MF", - "-MT", - "-MQ", - "-MJ", - "-Xclang", - } - standalone_drop = { - "-c", - "-MD", - "-MMD", - "-MP", - "-MM", - "-M", - "-E", - "-S", - } - - index = 0 - while index < len(arguments): - arg = arguments[index] - if skip_next: - skip_next = False - index += 1 - continue - - if arg in flags_with_value: - skip_next = True - index += 1 - continue - if arg in standalone_drop: - index += 1 - continue - if arg.startswith("-o") and len(arg) > 2: - index += 1 - continue - if arg.startswith("-MF") and len(arg) > 3: - index += 1 - continue - if arg.startswith("-MT") and len(arg) > 3: - index += 1 - continue - if arg.startswith("-MQ") and len(arg) > 3: - index += 1 - continue - if arg.startswith("-MJ") and len(arg) > 3: - index += 1 - continue - - cleaned.append(arg) - index += 1 - - return cleaned - - -def parse_makefile_dependencies(stdout_text: str) -> list[str]: - flattened = stdout_text.replace("\\\n", " ").replace("\n", " ") - if ":" not in flattened: - return [] - dep_payload = flattened.split(":", 1)[1].strip() - if not dep_payload: - return [] - return shlex.split(dep_payload) - - -def compute_tu_dependencies(entry: CompileCommandEntry) -> None: - dep_cmd = clean_command_for_dependency_scan(entry.arguments) - if not dep_cmd: - entry.dep_error = "Empty compile command after sanitization" - entry.dependencies = {entry.source} - return - - dep_cmd.extend(["-MM", "-MF", "-", "-MT", "__pixie_tu__"]) - source_arg = str(entry.source) - if source_arg not in dep_cmd: - dep_cmd.append(source_arg) - - try: - proc = run_command(dep_cmd, cwd=entry.directory, check=False) - except FileNotFoundError as exc: - entry.dep_error = str(exc) - entry.dependencies = {entry.source} - return - - dependencies: set[Path] = {entry.source} - if proc.returncode != 0: - stderr = proc.stderr.strip() - entry.dep_error = ( - stderr if stderr else f"Dependency scan failed ({proc.returncode})" - ) - entry.dependencies = dependencies - return - - for dep in parse_makefile_dependencies(proc.stdout): - dep_path = Path(dep) - resolved = ( - dep_path.resolve() - if dep_path.is_absolute() - else (entry.directory / dep_path).resolve() - ) - dependencies.add(resolved) - - entry.dependencies = dependencies - - -def is_build_infra_change(repo_root: Path, changed: set[Path]) -> bool: - for path in changed: - if path.name in BUILD_INFRA_FILES: - return True - try: - rel = path.relative_to(repo_root) - except ValueError: - continue - rel_text = rel.as_posix() - if rel_text.startswith("cmake/"): - return True - return False - - -def identify_benchmark_targets( - entries: list[CompileCommandEntry], repo_root: Path -) -> set[str]: - benchmark_targets: set[str] = set() - targets_present = {entry.target for entry in entries if entry.target} - for entry in entries: - if entry.target is None: - continue - try: - rel = entry.source.relative_to(repo_root) - rel_text = rel.as_posix() - except ValueError: - rel_text = entry.source.as_posix() - - if rel_text.startswith("src/benchmarks/"): - benchmark_targets.add(entry.target) - - benchmark_targets.update(targets_present.intersection(KNOWN_BENCHMARK_TARGETS)) - return benchmark_targets - - -def is_benchmark_source(source: Path, repo_root: Path) -> bool: - try: - rel_text = source.relative_to(repo_root).as_posix() - except ValueError: - return False - return rel_text.startswith("src/benchmarks/") - - -def dedupe_entries_by_target_source( - entries: list[CompileCommandEntry], -) -> list[CompileCommandEntry]: - deduped: list[CompileCommandEntry] = [] - seen: set[tuple[str | None, Path]] = set() - for entry in entries: - key = (entry.target, entry.source) - if key in seen: - continue - seen.add(key) - deduped.append(entry) - return deduped - - -def discover_clangxx(explicit: str | None) -> str: - if explicit: - return explicit - - candidates = [ - "clang++", - "clang++-19", - "clang++-18", - "clang++-17", - "clang++-16", - ] - for candidate in candidates: - resolved = shutil.which(candidate) - if resolved: - return resolved - raise FileNotFoundError( - "clang++ was not found on PATH. Provide --clangxx to select a clang compiler." - ) - - -def clean_command_for_ast(arguments: list[str], clangxx: str) -> list[str]: - cleaned = clean_command_for_dependency_scan(arguments) - if not cleaned: - return [] - cleaned[0] = clangxx - cleaned.extend(["-Xclang", "-ast-dump=json", "-fsyntax-only"]) - return cleaned - - -def normalize_path_candidate(path_text: str | None, working_dir: Path) -> Path | None: - if not path_text: - return None - path = Path(path_text) - if path.is_absolute(): - return path.resolve() - return (working_dir / path).resolve() - - -def file_from_loc(loc: dict[str, Any] | None, working_dir: Path) -> Path | None: - if not isinstance(loc, dict): - return None - if "file" in loc: - return normalize_path_candidate(str(loc["file"]), working_dir) - for nested_key in ("spellingLoc", "expansionLoc", "includedFrom"): - nested_loc = loc.get(nested_key) - if isinstance(nested_loc, dict): - resolved = file_from_loc(nested_loc, working_dir) - if resolved is not None: - return resolved - return None - - -def iter_ast_nodes(node: Any): - if isinstance(node, dict): - yield node - inner = node.get("inner", []) - if isinstance(inner, list): - for child in inner: - yield from iter_ast_nodes(child) - elif isinstance(node, list): - for item in node: - yield from iter_ast_nodes(item) - - -def referenced_decl_file(node: dict[str, Any], working_dir: Path) -> Path | None: - referenced = node.get("referencedDecl") - if not isinstance(referenced, dict): - return None - return file_from_loc(referenced.get("loc"), working_dir) - - -def node_references_changed_symbol( - node: dict[str, Any], - changed_symbol_names: set[str], -) -> bool: - if not changed_symbol_names: - return False - - for subnode in iter_ast_nodes(node): - if not isinstance(subnode, dict): - continue - - kind = subnode.get("kind") - if kind == "MemberExpr": - member_name = subnode.get("name") - if isinstance(member_name, str) and member_name in changed_symbol_names: - return True - - if kind == "DeclRefExpr": - ref_decl = subnode.get("referencedDecl") - if not isinstance(ref_decl, dict): - continue - ref_name = ref_decl.get("name") - if isinstance(ref_name, str) and ref_name in changed_symbol_names: - return True - - return False - - -def call_expr_callee_name(call_expr: dict[str, Any]) -> str | None: - for node in iter_ast_nodes(call_expr): - if not isinstance(node, dict): - continue - if node.get("kind") != "DeclRefExpr": - continue - referenced = node.get("referencedDecl") - if isinstance(referenced, dict) and isinstance(referenced.get("name"), str): - return referenced["name"] - return None - - -def string_literals_in_node(node: dict[str, Any]) -> list[str]: - values: list[str] = [] - for cur in iter_ast_nodes(node): - if not isinstance(cur, dict): - continue - if cur.get("kind") != "StringLiteral": - continue - value = cur.get("value") - if isinstance(value, str): - if len(value) >= 2 and value[0] == '"' and value[-1] == '"': - value = value[1:-1] - values.append(value) - return values - - -def benchmark_names_from_source(source: Path) -> set[str]: - names: set[str] = set() - if not source.exists(): - return names - text = source.read_text(encoding="utf-8", errors="replace") - for match in re.finditer(r"BENCHMARK\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)", text): - names.add(match.group(1)) - for match in re.finditer(r"register_op\(\s*\"([^\"]+)\"", text): - names.add(match.group(1)) - return names - - -def ast_analyze_entry( - entry: CompileCommandEntry, - changed_files: set[Path], - changed_symbol_names: set[str], - clangxx: str, -) -> AstImpactResult: - result = AstImpactResult() - - ast_cmd = clean_command_for_ast(entry.arguments, clangxx) - if not ast_cmd: - result.ast_error = "Failed to build AST command" - return result - - try: - proc = run_command(ast_cmd, cwd=entry.directory, check=False) - except FileNotFoundError as exc: - result.ast_error = str(exc) - return result - - if proc.returncode != 0: - stderr = proc.stderr.strip() - result.ast_error = ( - stderr if stderr else f"AST command failed ({proc.returncode})" - ) - return result - - try: - ast_root = json.loads(proc.stdout) - except json.JSONDecodeError as exc: - result.ast_error = f"Invalid AST JSON: {exc}" - return result - - function_callees: dict[str, set[str]] = defaultdict(set) - direct_impacted_functions: set[str] = set() - dynamic_benchmarks_by_function: dict[str, set[str]] = defaultdict(set) - - for node in iter_ast_nodes(ast_root): - if not isinstance(node, dict): - continue - - if node.get("kind") not in {"FunctionDecl", "CXXMethodDecl"}: - continue - - function_name = node.get("name") - if not isinstance(function_name, str) or not function_name: - continue - - if function_name.startswith("BM_"): - result.benchmark_names.add(function_name) - - function_callees.setdefault(function_name, set()) - - function_loc = file_from_loc(node.get("loc"), entry.directory) - is_directly_impacted = function_loc in changed_files - if not is_directly_impacted: - is_directly_impacted = node_references_changed_symbol( - node, changed_symbol_names - ) - - for subnode in iter_ast_nodes(node): - if not isinstance(subnode, dict): - continue - - sub_kind = subnode.get("kind") - if sub_kind in {"CallExpr", "CXXMemberCallExpr", "CXXOperatorCallExpr"}: - callee = call_expr_callee_name(subnode) - if callee: - function_callees[function_name].add(callee) - - if callee == "register_op": - literal_values = string_literals_in_node(subnode) - if literal_values: - dynamic_benchmarks_by_function[function_name].add( - literal_values[0] - ) - - if not is_directly_impacted: - ref_file = referenced_decl_file(subnode, entry.directory) - if ref_file in changed_files: - is_directly_impacted = True - - if is_directly_impacted: - direct_impacted_functions.add(function_name) - - # Reverse call-graph propagation: if a function is directly impacted, - # every caller in this TU is impacted as well (fixed-point DFS/BFS). - callers_of: dict[str, set[str]] = defaultdict(set) - for caller, callees in function_callees.items(): - for callee in callees: - callers_of[callee].add(caller) - - impacted_functions = set(direct_impacted_functions) - stack = list(direct_impacted_functions) - while stack: - callee_name = stack.pop() - for caller_name in callers_of.get(callee_name, set()): - if caller_name in impacted_functions: - continue - impacted_functions.add(caller_name) - stack.append(caller_name) - - for function_name in impacted_functions: - if function_name.startswith("BM_"): - result.affected_names.add(function_name) - - for function_name, names in dynamic_benchmarks_by_function.items(): - result.benchmark_names.update(names) - if function_name in impacted_functions: - result.affected_names.update(names) - - return result - - -def regex_for_benchmarks(names: set[str]) -> str | None: - if not names: - return None - ordered = sorted(names) - body = "|".join(re.escape(name) for name in ordered) - return rf"^({body})(/|$)" - - -def relpath_or_abs(path: Path, root: Path) -> str: - try: - return path.relative_to(root).as_posix() - except ValueError: - return path.as_posix() - - -def main() -> int: - cli = parse_args() - - try: - repo_root = git_repo_root() - changed_files = git_changed_files(repo_root, cli.baseline, cli.head) - if cli.include_working_tree: - changed_files.update(git_working_tree_changed_files(repo_root)) - changed_line_map = git_changed_line_map( - repo_root, - cli.baseline, - cli.head, - cli.include_working_tree, - ) - changed_symbol_names = collect_changed_symbol_names(changed_line_map) - compile_commands_path = resolve_compile_commands( - repo_root, cli.compile_commands - ) - entries = load_compile_commands(compile_commands_path, repo_root) - except FileNotFoundError as exc: - print(f"error: {exc}", file=sys.stderr) - return 2 - except subprocess.CalledProcessError as exc: - stderr = (exc.stderr or "").strip() - if stderr: - print(f"error: {stderr}", file=sys.stderr) - else: - print(f"error: command failed: {' '.join(exc.cmd)}", file=sys.stderr) - return 2 - - target_to_entries: dict[str, list[CompileCommandEntry]] = defaultdict(list) - source_to_entries: dict[Path, list[CompileCommandEntry]] = defaultdict(list) - for entry in entries: - source_to_entries[entry.source].append(entry) - if entry.target: - target_to_entries[entry.target].append(entry) - - benchmark_targets = identify_benchmark_targets(entries, repo_root) - all_targets = {entry.target for entry in entries if entry.target} - benchmark_entries = dedupe_entries_by_target_source( - [entry for entry in entries if entry.target in benchmark_targets] - ) - - infra_change = is_build_infra_change(repo_root, changed_files) - relevant_changed_files = { - path - for path in changed_files - if is_project_source(path, repo_root) - or path.name in BUILD_INFRA_FILES - or relpath_or_abs(path, repo_root).startswith("cmake/") - } - has_header_changes = any( - path.suffix.lower() in HEADER_EXTENSIONS for path in relevant_changed_files - ) - benchmark_source_extensions = {".c", ".cc", ".cpp", ".cxx"} - only_benchmark_source_changes = bool(relevant_changed_files) and all( - is_benchmark_source(path, repo_root) - and path.suffix.lower() in benchmark_source_extensions - for path in relevant_changed_files - ) - - directly_affected_targets: set[str] = set() - for changed_path in changed_files: - for entry in source_to_entries.get(changed_path, []): - if entry.target: - directly_affected_targets.add(entry.target) - - dependency_scan_entries: list[CompileCommandEntry] = [] - if not infra_change and not only_benchmark_source_changes: - if has_header_changes: - dependency_scan_entries = dedupe_entries_by_target_source(entries) - else: - dependency_scan_entries = benchmark_entries - - if dependency_scan_entries: - with concurrent.futures.ThreadPoolExecutor( - max_workers=min(8, (os.cpu_count() or 4)) - ) as pool: - list(pool.map(compute_tu_dependencies, dependency_scan_entries)) - - affected_targets: set[str] = set(directly_affected_targets) - for entry in dependency_scan_entries: - has_changed_dependency = any(dep in changed_files for dep in entry.dependencies) - if has_changed_dependency and entry.target: - affected_targets.add(entry.target) - - if infra_change: - affected_targets.update(all_targets) - - dependency_impacted_benchmark_targets = affected_targets.intersection( - benchmark_targets - ) - impacted_benchmark_entries = [ - entry - for entry in benchmark_entries - if entry.target in dependency_impacted_benchmark_targets - ] - - ast_errors: dict[str, str] = {} - benchmark_target_to_names: dict[str, set[str]] = defaultdict(set) - benchmark_target_to_affected: dict[str, set[str]] = defaultdict(set) - warnings: list[str] = [] - ast_fallback_used = False - ast_entries_scanned = 0 - - if impacted_benchmark_entries: - try: - clangxx = discover_clangxx(cli.clangxx) - except FileNotFoundError as exc: - clangxx = "" - warnings.append(str(exc)) - - if not clangxx: - ast_fallback_used = True - for entry in impacted_benchmark_entries: - target_name = entry.target or "" - fallback_names = benchmark_names_from_source(entry.source) - benchmark_target_to_names[target_name].update(fallback_names) - benchmark_target_to_affected[target_name].update(fallback_names) - else: - max_ast_workers = min(2, (os.cpu_count() or 2)) - with concurrent.futures.ThreadPoolExecutor( - max_workers=max_ast_workers - ) as pool: - futures = { - pool.submit( - ast_analyze_entry, - entry, - changed_files, - changed_symbol_names, - clangxx, - ): entry - for entry in impacted_benchmark_entries - } - ast_entries_scanned = len(futures) - for future in concurrent.futures.as_completed(futures): - entry = futures[future] - target_name = entry.target or "" - source_path = entry.source - source_is_changed = source_path in changed_files - - try: - ast_result = future.result(timeout=120) - except Exception as exc: - ast_result = AstImpactResult( - ast_error=f"AST worker failed: {exc}" - ) - - if ast_result.ast_error: - ast_errors[relpath_or_abs(source_path, repo_root)] = ( - ast_result.ast_error - ) - - benchmark_names = ast_result.benchmark_names - if not benchmark_names: - benchmark_names = benchmark_names_from_source(source_path) - benchmark_target_to_names[target_name].update(benchmark_names) - - if ast_result.affected_names: - benchmark_target_to_affected[target_name].update( - ast_result.affected_names - ) - elif source_is_changed or ast_result.ast_error: - benchmark_target_to_affected[target_name].update( - benchmark_names - ) - if benchmark_names: - ast_fallback_used = True - - if infra_change and benchmark_targets: - for target_name in sorted(benchmark_targets): - for entry in target_to_entries.get(target_name, []): - names = benchmark_names_from_source(entry.source) - benchmark_target_to_names[target_name].update(names) - benchmark_target_to_affected[target_name].update(names) - - if infra_change: - affected_benchmark_targets = sorted(benchmark_targets) - else: - affected_benchmark_targets = sorted( - target for target, names in benchmark_target_to_affected.items() if names - ) - - all_affected_benchmarks: set[str] = set() - for names in benchmark_target_to_affected.values(): - all_affected_benchmarks.update(names) - - dep_scan_failures = { - relpath_or_abs(entry.source, repo_root): entry.dep_error - for entry in dependency_scan_entries - if entry.dep_error - } - - scope_mode = "normal" - if infra_change: - scope_mode = "infra_fallback" - elif ast_fallback_used: - scope_mode = "ast_fallback" - - report: dict[str, Any] = { - "baseline": cli.baseline, - "head": cli.head, - "include_working_tree": cli.include_working_tree, - "changed_symbols": sorted(changed_symbol_names), - "compile_commands": relpath_or_abs(compile_commands_path, repo_root), - "changed_files": sorted( - relpath_or_abs(path, repo_root) for path in changed_files - ), - "affected_targets": sorted(affected_targets), - "affected_benchmark_targets": affected_benchmark_targets, - "affected_benchmarks": { - target: sorted(names) - for target, names in sorted(benchmark_target_to_affected.items()) - if names - }, - "suggested_filter_regex": regex_for_benchmarks(all_affected_benchmarks), - "dependency_entries_scanned": len(dependency_scan_entries), - "benchmark_entries_scanned": len(benchmark_entries), - "ast_entries_scanned": ast_entries_scanned, - "scope_mode": scope_mode, - "dependency_scan_failures": dep_scan_failures, - "ast_failures": ast_errors, - "warnings": warnings, - } - - if cli.format == "json": - json.dump(report, sys.stdout, indent=2) - sys.stdout.write("\n") - return 0 - - print(f"Baseline: {cli.baseline}") - print(f"Head: {cli.head}") - print(f"Compile commands: {report['compile_commands']}") - print(f"Scope mode: {report['scope_mode']}") - print( - "Scan counts: " - f"dependency={report['dependency_entries_scanned']}, " - f"benchmark={report['benchmark_entries_scanned']}, " - f"ast={report['ast_entries_scanned']}" - ) - print("") - - print(f"Changed files ({len(report['changed_files'])}):") - for item in report["changed_files"]: - print(f"- {item}") - if not report["changed_files"]: - print("- none") - - print("") - print(f"Affected targets ({len(report['affected_targets'])}):") - for item in report["affected_targets"]: - print(f"- {item}") - if not report["affected_targets"]: - print("- none") - - print("") - print(f"Affected benchmark targets ({len(report['affected_benchmark_targets'])}):") - for item in report["affected_benchmark_targets"]: - print(f"- {item}") - if not report["affected_benchmark_targets"]: - print("- none") - - print("") - print("Affected benchmark functions:") - if report["affected_benchmarks"]: - for target, names in report["affected_benchmarks"].items(): - print(f"- {target}:") - for name in names: - print(f" - {name}") - else: - print("- none") - - print("") - print("Suggested --benchmark_filter regex:") - print(report["suggested_filter_regex"] or "none") - - if dep_scan_failures: - print("") - print("Dependency scan failures:") - for source, error in dep_scan_failures.items(): - print(f"- {source}: {error}") - - if ast_errors: - print("") - print("AST failures:") - for source, error in ast_errors.items(): - print(f"- {source}: {error}") - - if warnings: - print("") - print("Warnings:") - for warning in warnings: - print(f"- {warning}") - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/.kilo/skills/benchmarks-compare-revisions/SKILL.md b/.kilo/skills/benchmarks-compare-revisions/SKILL.md deleted file mode 100644 index 0d77a3a..0000000 --- a/.kilo/skills/benchmarks-compare-revisions/SKILL.md +++ /dev/null @@ -1,223 +0,0 @@ ---- -name: benchmarks-compare-revisions -description: Compare benchmark performance between two git revisions via Google Benchmark compare.py, with optional hardware-counter comparison from diagnostic libpfm builds. ---- - -# Benchmarks Compare Revisions Skill - -Use this skill to compare performance between two git revisions. - -This workflow now depends on: - -1. `.kilo/skills/benchmarks-affected/SKILL.md` to determine affected benchmark targets/functions and produce a benchmark filter. -2. `.kilo/skills/benchmarks/SKILL.md` for build/run operational details. - -## Goal - -Build two separate benchmark binaries using short commit hashes as build suffixes, compare timing results with Google Benchmark compare.py, and optionally compare hardware counters across the same revisions. - -## Step 0 — Choose revisions, hashes, and options - -Pick a baseline and a contender revision. Use short commit hashes to suffix build directories so builds do not collide. - -Optional behavior flags: - -- `COLLECT_COUNTERS=1` to enable hardware-counter collection and analysis in addition to timing comparison. -- `COLLECT_COUNTERS=0` to run timing-only comparison. - -Counter collection is Linux-only and requires: - -- diagnostic builds with `BENCHMARK_ENABLE_LIBPFM=ON` -- perf permissions on the host (for access to performance counters) - -Example: -```bash -BASELINE=abc1234 -CONTENDER=def5678 -``` - -## Step 1 — Compute affected benchmark scope first - -Run `benchmarks-affected` from the contender checkout to derive the compare scope. - -Do not duplicate `benchmarks-affected` internals here (compile database selection, AST analysis, or fallback heuristics). Follow that skill directly and consume only its outputs. - -Inputs to pass through: - -- `--baseline ${BASELINE}` -- optional compile-commands path if auto-detection is not desired -- optional output format (`json` recommended for parsing) - -Consume these outputs from `benchmarks-affected`: - -- `suggested_filter_regex` -> set `FILTER` -- `affected_benchmark_targets` -> optionally constrain which benchmark binary/binaries to run -- `affected_benchmarks` -> function-level scope for validation/reporting - -If `FILTER` is empty, fall back to full benchmark binary compare (conservative mode). - -## Step 2 — Build both revisions - -Use the existing benchmarks skill build steps, but set the build suffix to include the short hash for each revision. - -Always build Release timing binaries. - -If `COLLECT_COUNTERS=1`, also build diagnostic binaries (RelWithDebInfo + libpfm) for both revisions. - -```bash -# Baseline -BUILD_SUFFIX=bench_${BASELINE} -git checkout ${BASELINE} -# Follow .kilo/skills/benchmarks/SKILL.md timing build instructions with this suffix -# If COLLECT_COUNTERS=1, also follow the diagnostic build instructions with this suffix - -# Contender -BUILD_SUFFIX=bench_${CONTENDER} -git checkout ${CONTENDER} -# Follow .kilo/skills/benchmarks/SKILL.md timing build instructions with this suffix -# If COLLECT_COUNTERS=1, also follow the diagnostic build instructions with this suffix -``` - -Expected build trees: - -- Timing: `build/benchmarks-all_bench_` -- Counters (optional): `build/benchmarks-diagnostic_bench_` - -## Step 3 — Compare using compare.py - -Use Google Benchmark compare tooling with a JSON-first flow to avoid long-running binary-vs-binary retries. - -Locate compare.py from the Google Benchmark dependency (installed under the build tree): -```bash -COMPARE_PY=build/benchmarks-all_bench_${BASELINE}/_deps/googlebenchmark-src/tools/compare.py -``` - -Verify Python deps once (compare.py imports numpy/scipy): -```bash -python3 -c "import numpy, scipy" -``` - -Generate baseline/contender JSON sequentially with explicit file outputs: -```bash -BASE_JSON=/tmp/bench_${BASELINE}.json -CONT_JSON=/tmp/bench_${CONTENDER}.json - -build/benchmarks-all_bench_${BASELINE}/benchmarks \ - --benchmark_report_aggregates_only=true \ - --benchmark_display_aggregates_only=true \ - --benchmark_format=json \ - --benchmark_out=${BASE_JSON} > /tmp/bench_${BASELINE}.log 2>&1 - -build/benchmarks-all_bench_${CONTENDER}/benchmarks \ - --benchmark_report_aggregates_only=true \ - --benchmark_display_aggregates_only=true \ - --benchmark_format=json \ - --benchmark_out=${CONT_JSON} > /tmp/bench_${CONTENDER}.log 2>&1 -``` - -Validate JSON before comparing: -```bash -python3 -m json.tool ${BASE_JSON} > /dev/null -python3 -m json.tool ${CONT_JSON} > /dev/null -``` - -Run the comparison: -```bash -python3 ${COMPARE_PY} -a benchmarks ${BASE_JSON} ${CONT_JSON} -``` - -Use the affected filter from Step 1 when generating JSON files: -```bash -if [ -n "${FILTER}" ]; then - FILTER_ARG="--benchmark_filter=${FILTER}" -else - FILTER_ARG="" -fi - -build/benchmarks-all_bench_${BASELINE}/benchmarks ${FILTER_ARG} --benchmark_report_aggregates_only=true --benchmark_display_aggregates_only=true ... -build/benchmarks-all_bench_${CONTENDER}/benchmarks ${FILTER_ARG} --benchmark_report_aggregates_only=true --benchmark_display_aggregates_only=true ... -``` - -## Step 3b — Compare hardware counters (optional, Linux only) - -Run this step only when `COLLECT_COUNTERS=1`. - -1. Preflight first with one tiny counter-enabled benchmark run from a diagnostic binary. If output includes warnings such as `Failed to get a file descriptor for performance counter`, mark counters unavailable and skip counter collection. -2. Run baseline and contender diagnostic binaries sequentially with explicit JSON outputs and the same filter scope: - -```bash -BASE_COUNTERS_JSON=/tmp/bench_counters_${BASELINE}.json -CONT_COUNTERS_JSON=/tmp/bench_counters_${CONTENDER}.json - -build/benchmarks-diagnostic_bench_${BASELINE}/benchmarks \ - ${FILTER_ARG} \ - --benchmark_counters_tabular=true \ - --benchmark_format=json \ - --benchmark_out=${BASE_COUNTERS_JSON} > /tmp/bench_counters_${BASELINE}.log 2>&1 - -build/benchmarks-diagnostic_bench_${CONTENDER}/benchmarks \ - ${FILTER_ARG} \ - --benchmark_counters_tabular=true \ - --benchmark_format=json \ - --benchmark_out=${CONT_COUNTERS_JSON} > /tmp/bench_counters_${CONTENDER}.log 2>&1 -``` - -3. Validate JSON files before consuming: - -```bash -python3 -m json.tool ${BASE_COUNTERS_JSON} > /dev/null -python3 -m json.tool ${CONT_COUNTERS_JSON} > /dev/null -``` - -4. Collect and compare these counter families when present: - -- `instructions`, `cycles` -- `cache-misses`, `cache-references` -- `branch-misses`, `branches` -- `L1-dcache-load-misses` - -5. Compute derived metrics when denominators are non-zero: - -- IPC = `instructions / cycles` -- Cache miss rate = `cache-misses / cache-references` -- Branch mispredict rate = `branch-misses / branches` - -6. Pair baseline and contender rows by benchmark name, compute deltas, and flag anomalies where timing direction conflicts with key counter direction. - -7. Emit a canonical summary table for downstream consumers: - -```markdown -| Benchmark | IPC (base -> new) | Cache Miss Rate (base -> new) | Branch Mispredict (base -> new) | Anomaly? | -|---|---:|---:|---:|---| -``` - -## Retry and Timeout Policy - -1. Run benchmarks sequentially; do not background with `nohup`/`&`. -2. If a run times out, narrow filter and retry once. -3. Maximum retries per benchmark group: 1. -4. If still failing, emit blocked/partial findings instead of repeated attempts. - -Apply this policy to both timing and counter runs. - -## Step 4 — Record findings - -Capture and return: - -- compare.py output (terminal transcript or redirected file) -- effective filter used -- timing JSON artifacts for baseline and contender -- `counters_available` (`true`/`false`) -- if `counters_available=false`, a reason string (unsupported OS, missing libpfm, perf permission denied, preflight failure) -- if counters are available: counter JSON artifacts, derived metrics table, and anomaly list - -## Best Practices / Guardrails - -1. **Release only**: never compare Debug binaries. -2. **Short hash suffixes**: keep build dirs isolated per revision (example: `bench_`). -3. **Same host, same conditions**: do not compare across different machines or power profiles. -4. **Filter from analysis**: use `benchmarks-affected` output instead of hand-crafted filters whenever possible. -5. **Pin frequency**: for stable numbers, follow benchmark skill guidance on CPU governor. -6. **Counter collection is optional and Linux-only**: when unavailable, return timing-only outputs with `counters_available=false`. -7. **Always preflight counters**: do not run full counter collection if preflight fails. -8. **Keep build types separated**: timing uses `benchmarks-all_*` Release builds; counters use `benchmarks-diagnostic_*` RelWithDebInfo builds; never Debug. diff --git a/.kilo/skills/benchmarks/SKILL.md b/.kilo/skills/benchmarks/SKILL.md deleted file mode 100644 index ea050a4..0000000 --- a/.kilo/skills/benchmarks/SKILL.md +++ /dev/null @@ -1,195 +0,0 @@ ---- -name: benchmarks -description: Run Google Benchmark binaries for the Pixie project, including filtering, hardware counters, and perf profiling. ---- - -# Benchmarks Skill - -You now have expertise in running and interpreting Pixie benchmarks. Follow these workflows: - -## Build Directory Convention - -Use a short commit hash suffix for committed revisions: - -```bash -BUILD_SUFFIX=$(git rev-parse --short HEAD) -``` - -If the worktree has uncommitted changes, append a descriptive suffix so results -cannot be confused with a clean HEAD build: - -```bash -BUILD_SUFFIX=$(git rev-parse --short HEAD)-dirty -``` - -If not a git repository, use - -```bash -BUILD_SUFFIX=agent -``` - -## CRITICAL: Never Run Benchmarks from a Debug Build - -> **Always pass `--config Release` (or `--config RelWithDebInfo`) to `cmake --build`.** -> Multi-config generators (MSVC, Xcode) default to `Debug` if no `--config` is given. -> Google Benchmark will print `***WARNING*** Library was built as DEBUG` and timings will -> be 3-10x slower and meaningless. Always verify the binary path contains `Release/` or -> `RelWithDebInfo/`, never `Debug/`. - -## Step 1 — Build - -If benchmarks affected by the changes are easily tractable build only related targets. - -**Pure timing (benchmarks, Release):** -```bash -cmake -B build/benchmarks_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Release -DPIXIE_BENCHMARKS=ON -cmake --build build/benchmarks_${BUILD_SUFFIX} --config Release -j -``` - -**Hardware counters / verbose report (benchmarks-diagnostic, RelWithDebInfo, Linux only):** -```bash -cmake -B build/benchmarks-diagnostic_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=RelWithDebInfo -DPIXIE_BENCHMARKS=ON -DBENCHMARK_ENABLE_LIBPFM=ON -DPIXIE_DIAGNOSTICS=ON -cmake --build build/benchmarks-diagnostic_${BUILD_SUFFIX} --config RelWithDebInfo -j -``` - -## Step 2 — Run - -Prefer running benchmarks with filtering passing the benchmarks that should be affected. - -Execution guardrails: -- Run benchmark commands sequentially in CI. -- Avoid background jobs (`nohup`, `&`) for benchmark collection. -- Always write machine-readable results with `--benchmark_out` when data is later parsed. - -### Available benchmark binaries - -| Binary | What it covers | -|--------|---------------| -| `benchmarks` | BitVector rank/select | -| `bench_rmm` | RmM Tree operations | -| `bench_rmm_sdsl` | RmM vs sdsl-lite comparison | -| `louds_tree_benchmarks` | LOUDS Tree traversal | -| `alignment_comparison` | Memory alignment effects | - -Binary paths vary by generator type: - -| Generator | Path pattern | -|-----------|-------------| -| MSVC / Xcode (multi-config) | `build/_/Release/` | -| Ninja / Make (single-config) | `build/_/` | - -### Run all benchmarks in a binary - -```bash -# Multi-config (MSVC/Xcode) -build/benchmarks_${BUILD_SUFFIX}/Release/benchmarks - -# Single-config (Ninja/Make) -build/benchmarks_${BUILD_SUFFIX}/benchmarks -``` - -### Filter benchmarks with a regex (FILTER parameter) - -```bash -FILTER="BM_Rank" # change to match benchmark names, e.g. "BM_Select", "BM_Louds", "" - -# Multi-config -build/benchmarks_${BUILD_SUFFIX}/Release/benchmarks --benchmark_filter="${FILTER}" - -# Single-config -build/benchmarks_${BUILD_SUFFIX}/benchmarks --benchmark_filter="${FILTER}" -``` - -Examples: -```bash -# Only rank benchmarks -... --benchmark_filter="BM_Rank" - -# Only select on non-interleaved layouts -... --benchmark_filter="BM_Select.*NonInterleaved" - -# List all available benchmark names without running -... --benchmark_list_tests=true -``` - -### Run with hardware counters (benchmarks-diagnostic build, Linux only) - -The `--benchmark_perf_counters` flag requests hardware counter collection via libpfm. Counter names are platform-specific but common ones include `CYCLES`, `INSTRUCTIONS`, `CACHE-MISSES`, `CACHE-REFERENCES`, `BRANCH-MISSES`, `BRANCH-INSTRUCTIONS`. - -```bash -build/benchmarks-diagnostic_${BUILD_SUFFIX}/RelWithDebInfo/benchmarks \ - --benchmark_filter="${FILTER}" \ - --benchmark_perf_counters=CYCLES,INSTRUCTIONS,CACHE-MISSES \ - --benchmark_counters_tabular=true -``` - -### Save results to file - -```bash -build/benchmarks_${BUILD_SUFFIX}/Release/benchmarks \ - --benchmark_filter="${FILTER}" \ - --benchmark_report_aggregates_only=true \ - --benchmark_display_aggregates_only=true \ - --benchmark_format=json \ - --benchmark_out=results.json -``` - -Validate output before consuming: -```bash -python3 -m json.tool results.json > /dev/null -``` - -## Step 3 — Profile with perf (Linux only) - -Use when hardware counters alone are not enough and you need a full call-graph profile for post-processing. - -**Record:** -```bash -perf record -g -F 999 \ - build/benchmarks-diagnostic_${BUILD_SUFFIX}/benchmarks \ - --benchmark_filter="${FILTER}" \ - --benchmark_min_time=5s -``` - -**Quick report (terminal):** -```bash -perf report --stdio -``` - -**Flame graph (requires FlameGraph scripts):** -```bash -perf script | stackcollapse-perf.pl | flamegraph.pl > flamegraph.html -``` - -**Export for external tools (Hotspot, Firefox Profiler):** -```bash -perf script -F +pid > perf.data.txt -# or open with `hotspot perf.data` -``` - -## Useful Benchmark Flags - -| Flag | Purpose | -|------|---------| -| `--benchmark_filter=` | Run only matching benchmarks | -| `--benchmark_list_tests=true` | List names without running | -| `--benchmark_repetitions=` | Repeat each benchmark n times | -| `--benchmark_min_time=` | Minimum run time per benchmark | -| `--benchmark_format=json` | Machine-readable output | -| `--benchmark_out=` | Save output to file | -| `--benchmark_perf_counters=CYCLES,INSTRUCTIONS,...` | Collect hardware perf counters (requires libpfm build) | -| `--benchmark_counters_tabular=true` | Align user/perf counter columns into a table | -| `--benchmark_time_unit=ms` | Change display unit (ns/us/ms/s) | - -## Best Practices - -1. **Never run from a Debug binary**: always use `--config Release` at build time; check path contains `Release/` -2. **Use benchmarks for clean timing**: Release optimizations, no debug info, no libpfm overhead -3. **Use benchmarks-diagnostic for hardware counters**: RelWithDebInfo + libpfm; Linux only -4. **Use perf for deep profiling**: when counters point to a hotspot but don't explain it -5. **Pin CPU frequency** before timing runs: `sudo cpupower frequency-set -g performance` -6. **Filter to reduce noise**: narrow the filter regex to the benchmark under investigation -7. **Save JSON output** when comparing before/after changes: use `--benchmark_out` and diff the files -8. **Fail fast on environment issues**: precheck Python deps used by compare tooling (`numpy`, `scipy`) -9. **Use explicit retry limits**: on timeout, narrow scope and retry once; avoid repeated full-suite attempts -10. **Preflight perf counters**: run a tiny counter-enabled benchmark first; if counters unavailable, skip counter workflow diff --git a/.kilo/skills/cmake/SKILL.md b/.kilo/skills/cmake/SKILL.md deleted file mode 100644 index d3c7096..0000000 --- a/.kilo/skills/cmake/SKILL.md +++ /dev/null @@ -1,103 +0,0 @@ ---- -name: cmake -description: Compile and build CMake projects, including configuring build types, options, and running test binaries. ---- - -# CMake Build Skill - -You now have expertise in building and configuring CMake projects. Follow these workflows: - -## Build Directory Convention - -Use a short commit hash suffix for committed revisions: - -```bash -BUILD_SUFFIX=$(git rev-parse --short HEAD) -``` - -If the worktree has uncommitted changes, append a descriptive suffix so generated -artifacts cannot be confused with a clean HEAD build: - -```bash -BUILD_SUFFIX=$(git rev-parse --short HEAD)-dirty -``` - -If not a git repository, use - -```bash -BUILD_SUFFIX=agent -``` - -Build directories follow the pattern `build/_`. - -## Using Presets (Preferred When Available) - -> **Important**: `cmake --preset` sets cache variables and generator but its `binaryDir` cannot be -> overridden from the command line. To use a preset's settings with a custom build dir, pass the -> relevant `-D` flags explicitly together with `-B`. Use `--preset` only to discover what flags a -> preset applies. - -**List available presets:** -```bash -cmake --list-presets -``` - -**Replicate a preset's settings with a custom suffix build dir:** - -Release: -```bash -cmake -B build/release_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Release -cmake --build build/release_${BUILD_SUFFIX} -j -``` - -Debug: -```bash -cmake -B build/debug_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Debug -cmake --build build/debug_${BUILD_SUFFIX} -j -``` - -AddressSanitizer (mirrors `asan` preset): -```bash -cmake -B build/asan_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Debug -DPIXIE_BENCHMARKS=OFF -DENABLE_ADDRESS_SANITIZER=ON -cmake --build build/asan_${BUILD_SUFFIX} -j -``` - -Coverage (mirrors `coverage` preset): -```bash -cmake -B build/coverage_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Debug -DPIXIE_BENCHMARKS=OFF -DPIXIE_COVERAGE=ON -cmake --build build/coverage_${BUILD_SUFFIX} -j -``` - -Benchmarks (mirrors `benchmarks` preset): -```bash -cmake -B build/benchmarks_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Release -DPIXIE_BENCHMARKS=ON -cmake --build build/benchmarks_${BUILD_SUFFIX} -j -``` - -## Additional Feature Options - -**Disable AVX-512 (use AVX2 fallback):** -```bash -cmake -B build/release_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Release -DDISABLE_AVX512=ON -cmake --build build/release_${BUILD_SUFFIX} -j -``` - -**Custom march flag:** -```bash -cmake -B build/release_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Release -DMARCH=icelake-client -cmake --build build/release_${BUILD_SUFFIX} -j -``` - -**Tests only (no benchmarks or third-party deps):** -```bash -cmake -B build/release_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Release -DPIXIE_BENCHMARKS=OFF -cmake --build build/release_${BUILD_SUFFIX} -j -``` - -## Best Practices - -1. **Use out-of-source builds**: Keep build artifacts in `build/_` directories -2. **Presets fix binaryDir**: `--preset` cannot be combined with `-B` to change the build dir; replicate `-D` flags manually with `-B` instead -3. **Reconfigure when options change**: Rerun the `cmake -B ...` step when toggling options -4. **Clean build directory when needed**: Delete the entire build folder for a fresh configuration -5. **Match build type to task**: Release for performance work, Debug/ASan for correctness diff --git a/.kilo/skills/paper-search/SKILL.md b/.kilo/skills/paper-search/SKILL.md deleted file mode 100644 index 27f108c..0000000 --- a/.kilo/skills/paper-search/SKILL.md +++ /dev/null @@ -1,106 +0,0 @@ ---- -name: paper-search -description: "Search for academic papers across Semantic Scholar, arXiv, and CrossRef APIs. Returns unified results with title, authors, year, abstract, DOI, venue, and citation counts. Integrates with Zotero MCP tools for adding found papers to a Zotero library and generating BibTeX entries. Use when the user asks to find papers, search for related work, look up a DOI, or discover references on a topic." ---- - -# Paper Search - -Search external academic APIs for papers. Provides a unified interface across Semantic Scholar, arXiv, and CrossRef with optional Zotero integration. - -## Workflow - -### 1. Search for Papers - -Run the search script from the skill's `scripts/` directory: - -```bash -python3 scripts/search_papers.py --query "topic" --source semantic_scholar --limit 10 --format compact -``` - -Available sources: -- `semantic_scholar` — Default. Best for comprehensive search with citation counts. -- `arxiv` — Best for preprints and recent unpublished work. -- `crossref` — Best for published works and DOI-based metadata. -- `all` — Query all three sources (slower, results combined). - -Output formats: -- `json` — Full JSON output (default). Good for programmatic use. -- `compact` — Human-readable summary with title, authors, year, venue, citations, and truncated abstract. - -### 2. DOI Lookup - -Look up a specific paper by DOI: - -```bash -python3 scripts/search_papers.py --doi "10.1145/1234567.1234568" --format compact -``` - -### 3. Download PDFs - -Download open-access PDFs directly from search results: - -```bash -python3 scripts/search_papers.py --query "wavelet tree" --source arxiv --limit 3 --download ~/papers -``` - -- arXiv papers always have PDFs available. -- Semantic Scholar provides `openAccessPdf` URLs when available. -- CrossRef may provide PDF links via publisher APIs. - -The `--download` flag adds a `downloaded_path` field to each result in JSON output. - -### 4. Add to Zotero - -**Option A: Via DOI/URL (metadata only)** - -After finding relevant papers, add them to Zotero using the Zotero MCP tools: - -- `zotero_add_by_doi` — Preferred when DOI is available. Fetches full metadata from CrossRef. -- `zotero_add_by_url` — Use for arXiv papers or when only a URL is available. - -**Option B: Via downloaded PDF (metadata + attachment)** - -Download the PDF first, then add to Zotero with the PDF file: - -```bash -# Step 1: Download PDFs and get paths in JSON -python3 scripts/search_papers.py --doi "10.1007/978-3-540-73420-8_13" --download ~/papers --format json - -# Step 2: Use zotero_zotero_add_from_file with the downloaded_path -``` - -The agent should call `zotero_zotero_add_from_file` with the `downloaded_path` from the JSON output. This attaches the PDF to the Zotero item and attempts DOI-based metadata extraction. - -**Option C: Download + Zotero in one step** - -Use `--zotero` to download PDFs with paths formatted for easy Zotero import: - -```bash -python3 scripts/search_papers.py -q "succinct data structures" -s arxiv -n 3 --zotero --download ~/papers -``` - -After adding papers, update the semantic search database: - -``` -zotero_update_search_database -``` - -### 5. Generate BibTeX - -For papers already in Zotero, use `zotero_get_item_metadata` with `format: "bibtex"` to get BibTeX entries. Alternatively, use `zotero_fetch` for full metadata. - -For papers NOT in Zotero, BibTeX can be constructed from the search results' JSON fields (`authors`, `year`, `title`, `venue`, `doi`). - -## Guidance - -- Start with `semantic_scholar` for general queries — it has the broadest coverage and citation data. -- Use `arxiv` when looking for very recent work or preprints in CS/ML/physics. -- Use `crossref` for DOI lookups or when Semantic Scholar returns no results. -- When using `--source all`, results may contain duplicates (same paper from different sources). Deduplicate by DOI or title similarity. -- Citation counts are approximate and may differ across sources. -- arXiv results return the arXiv ID (e.g., `2301.12345`) which can be used with `zotero_add_by_url` via `https://arxiv.org/abs/2301.12345`. - -## API Quirks - -- **arXiv `atom:id` is NOT a DOI** — it contains an arXiv URL like `http://arxiv.org/abs/2301.12345`. Store the extracted ID in `arxiv_id` only; set `doi` to `None` for arXiv results. Writing the arXiv URL into `doi` produces invalid DOI metadata downstream (e.g., Zotero import). -- **CrossRef `select` must include `link`** — the `link` field is needed for `pdf_url` extraction. If omitted from `select`, the API won't return link metadata and `pdf_url` will silently be empty for all CrossRef results. diff --git a/.kilo/skills/paper-search/references/api_reference.md b/.kilo/skills/paper-search/references/api_reference.md deleted file mode 100644 index dcb5aa5..0000000 --- a/.kilo/skills/paper-search/references/api_reference.md +++ /dev/null @@ -1,32 +0,0 @@ -# External Paper Search APIs - -## Semantic Scholar - -- **Base URL**: `https://api.semanticscholar.org/graph/v1/` -- **Rate limit**: 1 req/sec without API key, 10 req/sec with key -- **No auth required** for basic usage -- **Fields**: title, authors, year, abstract, externalIds (DOI, ArXiv), venue, citationCount -- **Best for**: Comprehensive academic search with citation counts - -## arXiv - -- **Base URL**: `http://export.arxiv.org/api/query` -- **Rate limit**: Be nice, ~3 sec between requests -- **No auth required** -- **Returns**: XML (Atom feed) -- **Best for**: Preprints, recent work not yet published - -## CrossRef - -- **Base URL**: `https://api.crossref.org/` -- **Rate limit**: 50 req/sec with polite pool (include `mailto` header) -- **No auth required** -- **Best for**: DOI lookup, published works, metadata enrichment - -## Zotero Integration - -After finding papers via external search, use Zotero MCP tools: - -1. `zotero_add_by_doi` — Add paper by DOI (fetches metadata from CrossRef) -2. `zotero_add_by_url` — Add paper by URL (arXiv, DOI URLs) -3. `zotero_update_search_database` — Update semantic search index after adding diff --git a/.kilo/skills/paper-search/scripts/search_papers.py b/.kilo/skills/paper-search/scripts/search_papers.py deleted file mode 100644 index c65351d..0000000 --- a/.kilo/skills/paper-search/scripts/search_papers.py +++ /dev/null @@ -1,333 +0,0 @@ -#!/usr/bin/env python3 -"""Search external APIs for academic papers. - -Sources: Semantic Scholar, arXiv, CrossRef. -Outputs unified JSON to stdout. - -Usage: - python3 search_papers.py --query "wavelet tree succinct" --source semantic_scholar --limit 10 - python3 search_papers.py --query "succinct data structures" --source arxiv --limit 5 - python3 search_papers.py --doi "10.1145/123" --source crossref - python3 search_papers.py --query "rank select" --source all --limit 5 - python3 search_papers.py --query "wavelet tree" --source arxiv --limit 1 --download ~/papers - python3 search_papers.py --doi "10.1007/978-3-540-73420-8_13" --download ~/papers --zotero -""" - -import argparse -import json -import os -import re -import sys -import time -import urllib.error -import urllib.parse -import urllib.request -from pathlib import Path -from typing import Any - - -def _get(url: str, headers: dict[str, str] | None = None, timeout: int = 30, - retries: int = 2) -> dict: - for attempt in range(retries + 1): - req = urllib.request.Request(url, headers=headers or {}) - try: - with urllib.request.urlopen(req, timeout=timeout) as resp: - return json.loads(resp.read().decode()) - except urllib.error.HTTPError as e: - if e.code == 429 and attempt < retries: - wait = 2 ** attempt - print(f"Rate limited, retrying in {wait}s...", file=sys.stderr) - time.sleep(wait) - continue - print(f"HTTP {e.code}: {e.reason} for {url}", file=sys.stderr) - return {} - except urllib.error.URLError as e: - print(f"URL error: {e.reason} for {url}", file=sys.stderr) - return {} - return {} - - -def search_semantic_scholar(query: str, limit: int = 10) -> list[dict[str, Any]]: - """Search Semantic Scholar API.""" - params = urllib.parse.urlencode({ - "query": query, - "limit": limit, - "fields": "title,authors,year,abstract,externalIds,venue,publicationDate,citationCount,url,openAccessPdf", - }) - url = f"https://api.semanticscholar.org/graph/v1/paper/search?{params}" - data = _get(url, headers={"Accept": "application/json"}) - results = [] - for paper in data.get("data", []): - ext_ids = paper.get("externalIds") or {} - pdf_info = paper.get("openAccessPdf") or {} - results.append({ - "source": "semantic_scholar", - "title": paper.get("title", ""), - "authors": [a.get("name", "") for a in paper.get("authors", [])], - "year": paper.get("year"), - "abstract": paper.get("abstract", ""), - "doi": ext_ids.get("DOI"), - "arxiv_id": ext_ids.get("ArXiv"), - "venue": paper.get("venue", ""), - "citation_count": paper.get("citationCount"), - "url": paper.get("url", ""), - "pdf_url": pdf_info.get("url"), - }) - return results - - -def search_arxiv(query: str, limit: int = 10) -> list[dict[str, Any]]: - """Search arXiv API.""" - words = query.split() - if len(words) == 1: - search_term = f"all:{query}" - elif len(words) == 2: - # Phrase search for 2-word queries - search_term = f'all:"{query}"' - else: - # Use OR of phrase and individual terms for 3+ words - # This catches exact phrase matches AND papers with all terms - phrase = f'all:"{query}"' - and_terms = " AND ".join(f"all:{w}" for w in words) - search_term = f"({phrase}) OR ({and_terms})" - params = urllib.parse.urlencode({ - "search_query": search_term, - "start": 0, - "max_results": limit, - }) - url = f"http://export.arxiv.org/api/query?{params}" - req = urllib.request.Request(url) - try: - with urllib.request.urlopen(req, timeout=30) as resp: - xml_data = resp.read().decode() - except (urllib.error.URLError, urllib.error.HTTPError) as e: - print(f"arXiv API error: {e}", file=sys.stderr) - return [] - - import xml.etree.ElementTree as ET - root = ET.fromstring(xml_data) - ns = {"atom": "http://www.w3.org/2005/Atom"} - results = [] - for entry in root.findall("atom:entry", ns): - title = entry.findtext("atom:title", "", ns).strip().replace("\n", " ") - abstract = entry.findtext("atom:summary", "", ns).strip().replace("\n", " ") - authors = [a.findtext("atom:name", "", ns) for a in entry.findall("atom:author", ns)] - published = entry.findtext("atom:published", "", ns) - year = int(published[:4]) if published else None - arxiv_id = "" - for link in entry.findall("atom:link", ns): - href = link.get("href", "") - if "arxiv.org/abs/" in href: - arxiv_id = href.split("/abs/")[-1] - break - results.append({ - "source": "arxiv", - "title": title, - "authors": authors, - "year": year, - "abstract": abstract, - "doi": None, - "arxiv_id": arxiv_id, - "venue": "arXiv", - "citation_count": None, - "url": f"https://arxiv.org/abs/{arxiv_id}" if arxiv_id else "", - "pdf_url": f"https://arxiv.org/pdf/{arxiv_id}" if arxiv_id else None, - }) - return results - - -def search_crossref(query: str, limit: int = 10) -> list[dict[str, Any]]: - """Search CrossRef API.""" - params = urllib.parse.urlencode({ - "query": query, - "rows": limit, - "select": "DOI,title,author,published-print,abstract,container-title,is-referenced-by-count,URL,type,link", - }) - url = f"https://api.crossref.org/works?{params}" - data = _get(url, headers={"Accept": "application/json"}) - results = [] - for item in data.get("message", {}).get("items", []): - title_list = item.get("title", []) - title = title_list[0] if title_list else "" - authors = [] - for a in item.get("author", []): - name = f"{a.get('given', '')} {a.get('family', '')}".strip() - if name: - authors.append(name) - pub_date = item.get("published-print", {}).get("date-parts", [[None]]) - year = pub_date[0][0] if pub_date and pub_date[0] else None - venue_list = item.get("container-title", []) - venue = venue_list[0] if venue_list else "" - pdf_url = None - for link in item.get("link", []): - if "pdf" in link.get("content-type", ""): - pdf_url = link.get("URL") - break - results.append({ - "source": "crossref", - "title": title, - "authors": authors, - "year": year, - "abstract": item.get("abstract", ""), - "doi": item.get("DOI"), - "arxiv_id": None, - "venue": venue, - "citation_count": item.get("is-referenced-by-count"), - "url": item.get("URL", ""), - "pdf_url": pdf_url, - }) - return results - - -def lookup_doi(doi: str) -> dict[str, Any] | None: - """Look up a single paper by DOI via CrossRef.""" - url = f"https://api.crossref.org/works/{urllib.parse.quote(doi, safe='')}" - data = _get(url) - item = data.get("message") - if not item: - return None - title_list = item.get("title", []) - title = title_list[0] if title_list else "" - authors = [] - for a in item.get("author", []): - name = f"{a.get('given', '')} {a.get('family', '')}".strip() - if name: - authors.append(name) - pub_date = item.get("published-print", {}).get("date-parts", [[None]]) - year = pub_date[0][0] if pub_date and pub_date[0] else None - venue_list = item.get("container-title", []) - venue = venue_list[0] if venue_list else "" - pdf_url = None - for link in item.get("link", []): - if "pdf" in link.get("content-type", ""): - pdf_url = link.get("URL") - break - return { - "source": "crossref", - "title": title, - "authors": authors, - "year": year, - "abstract": item.get("abstract", ""), - "doi": item.get("DOI"), - "arxiv_id": None, - "venue": venue, - "citation_count": item.get("is-referenced-by-count"), - "url": item.get("URL", ""), - "pdf_url": pdf_url, - } - - -SOURCES = { - "semantic_scholar": search_semantic_scholar, - "arxiv": search_arxiv, - "crossref": search_crossref, -} - - -def _sanitize_filename(title: str) -> str: - """Generate a clean filename from paper title.""" - name = re.sub(r'[^\w\s-]', '', title.lower()) - name = re.sub(r'[\s]+', '_', name.strip()) - return name[:80] - - -def download_pdf(url: str, output_dir: str, paper: dict[str, Any]) -> str | None: - """Download a PDF and return the local path.""" - filename = _sanitize_filename(paper.get("title", "paper")) + ".pdf" - output_path = Path(output_dir) / filename - output_path.parent.mkdir(parents=True, exist_ok=True) - - req = urllib.request.Request(url, headers={ - "User-Agent": "Mozilla/5.0 (academic paper-search script)" - }) - try: - with urllib.request.urlopen(req, timeout=60) as resp: - content_type = resp.headers.get("Content-Type", "") - if "pdf" not in content_type and "octet-stream" not in content_type: - print(f"Warning: unexpected content type '{content_type}' for {url}", - file=sys.stderr) - with open(output_path, "wb") as f: - f.write(resp.read()) - print(f"Downloaded: {output_path}", file=sys.stderr) - return str(output_path) - except (urllib.error.URLError, urllib.error.HTTPError) as e: - print(f"Download failed for {url}: {e}", file=sys.stderr) - return None - - -def main(): - parser = argparse.ArgumentParser(description="Search for academic papers") - parser.add_argument("--query", "-q", help="Search query") - parser.add_argument("--doi", help="Look up a specific DOI") - parser.add_argument("--source", "-s", default="semantic_scholar", - choices=["semantic_scholar", "arxiv", "crossref", "all"], - help="API source (default: semantic_scholar)") - parser.add_argument("--limit", "-n", type=int, default=10, - help="Max results per source (default: 10)") - parser.add_argument("--format", "-f", default="json", - choices=["json", "compact"], - help="Output format (default: json)") - parser.add_argument("--download", "-d", metavar="DIR", - help="Download PDFs to DIR (requires pdf_url in results)") - parser.add_argument("--zotero", "-z", action="store_true", - help="Download PDFs and output paths for Zotero import (implies --download)") - args = parser.parse_args() - - if not args.query and not args.doi: - parser.error("Either --query or --doi is required") - - if args.zotero and not args.download: - args.download = "." - - results = [] - if args.doi: - paper = lookup_doi(args.doi) - if paper: - results.append(paper) - elif args.source == "all": - for name, func in SOURCES.items(): - try: - results.extend(func(args.query, args.limit)) - except Exception as e: - print(f"Error searching {name}: {e}", file=sys.stderr) - time.sleep(1) - else: - results = SOURCES[args.source](args.query, args.limit) - - if args.download: - for r in results: - pdf_url = r.get("pdf_url") - if pdf_url: - path = download_pdf(pdf_url, args.download, r) - r["downloaded_path"] = path - else: - r["downloaded_path"] = None - - if args.format == "json": - print(json.dumps(results, indent=2)) - else: - for i, r in enumerate(results, 1): - authors = ", ".join(r["authors"][:3]) - if len(r["authors"]) > 3: - authors += " et al." - doi_str = f" DOI: {r['doi']}" if r.get("doi") else "" - arxiv_str = f" arXiv: {r['arxiv_id']}" if r.get("arxiv_id") else "" - cite_str = f" Citations: {r['citation_count']}" if r.get("citation_count") else "" - pdf_str = f" PDF: {r['pdf_url']}" if r.get("pdf_url") else " PDF: N/A" - dl_str = "" - if r.get("downloaded_path"): - dl_str = f" Downloaded: {r['downloaded_path']}" - print(f"[{i}] {r['title']}") - print(f" {authors} ({r.get('year', '?')}) — {r.get('venue', '')}") - print(f" {r.get('url', '')}{doi_str}{arxiv_str}{cite_str}") - print(f" {pdf_str}{dl_str}") - if r.get("abstract"): - abstract = r["abstract"][:200] - if len(r["abstract"]) > 200: - abstract += "..." - print(f" {abstract}") - print() - - -if __name__ == "__main__": - main() diff --git a/.kilo/skills/pdf/SKILL.md b/.kilo/skills/pdf/SKILL.md deleted file mode 100644 index ddbce00..0000000 --- a/.kilo/skills/pdf/SKILL.md +++ /dev/null @@ -1,112 +0,0 @@ ---- -name: pdf -description: Process PDF files - extract text, create PDFs, merge documents. Use when user asks to read PDF, create PDF, or work with PDF files. ---- - -# PDF Processing Skill - -You now have expertise in PDF manipulation. Follow these workflows: - -## Reading PDFs - -**Option 1: Quick text extraction (preferred)** -```bash -# Using pdftotext (poppler-utils) -pdftotext input.pdf - # Output to stdout -pdftotext input.pdf output.txt # Output to file - -# If pdftotext not available, try: -python3 -c " -import fitz # PyMuPDF -doc = fitz.open('input.pdf') -for page in doc: - print(page.get_text()) -" -``` - -**Option 2: Page-by-page with metadata** -```python -import fitz # pip install pymupdf - -doc = fitz.open("input.pdf") -print(f"Pages: {len(doc)}") -print(f"Metadata: {doc.metadata}") - -for i, page in enumerate(doc): - text = page.get_text() - print(f"--- Page {i+1} ---") - print(text) -``` - -## Creating PDFs - -**Option 1: From Markdown (recommended)** -```bash -# Using pandoc -pandoc input.md -o output.pdf - -# With custom styling -pandoc input.md -o output.pdf --pdf-engine=xelatex -V geometry:margin=1in -``` - -**Option 2: Programmatically** -```python -from reportlab.lib.pagesizes import letter -from reportlab.pdfgen import canvas - -c = canvas.Canvas("output.pdf", pagesize=letter) -c.drawString(100, 750, "Hello, PDF!") -c.save() -``` - -**Option 3: From HTML** -```bash -# Using wkhtmltopdf -wkhtmltopdf input.html output.pdf - -# Or with Python -python3 -c " -import pdfkit -pdfkit.from_file('input.html', 'output.pdf') -" -``` - -## Merging PDFs - -```python -import fitz - -result = fitz.open() -for pdf_path in ["file1.pdf", "file2.pdf", "file3.pdf"]: - doc = fitz.open(pdf_path) - result.insert_pdf(doc) -result.save("merged.pdf") -``` - -## Splitting PDFs - -```python -import fitz - -doc = fitz.open("input.pdf") -for i in range(len(doc)): - single = fitz.open() - single.insert_pdf(doc, from_page=i, to_page=i) - single.save(f"page_{i+1}.pdf") -``` - -## Key Libraries - -| Task | Library | Install | -|------|---------|---------| -| Read/Write/Merge | PyMuPDF | `pip install pymupdf` | -| Create from scratch | ReportLab | `pip install reportlab` | -| HTML to PDF | pdfkit | `pip install pdfkit` + wkhtmltopdf | -| Text extraction | pdftotext | `brew install poppler` / `apt install poppler-utils` | - -## Best Practices - -1. **Always check if tools are installed** before using them -2. **Handle encoding issues** - PDFs may contain various character encodings -3. **Large PDFs**: Process page by page to avoid memory issues -4. **OCR for scanned PDFs**: Use `pytesseract` if text extraction returns empty diff --git a/.kilo/skills/setup-cpp-repo/SKILL.md b/.kilo/skills/setup-cpp-repo/SKILL.md deleted file mode 100644 index 808f997..0000000 --- a/.kilo/skills/setup-cpp-repo/SKILL.md +++ /dev/null @@ -1,137 +0,0 @@ ---- -name: setup-cpp-repo -description: Scaffold a new C++20 repository with CMake, Google Test, Google Benchmark, CI workflows, Doxygen docs, and Chromium code style. Use when the user asks to create a new C++ project, set up a C++ library, or initialize a C++ repository with modern tooling. ---- - -# setup-cpp-repo - -## Overview - -This skill generates a complete C++20 project scaffold following the conventions of the Pixie succinct data structures library. The generated repository is header-only by default and includes: - -- CMake build system with presets -- Google Test for unit testing -- Google Benchmark for performance benchmarks -- Doxygen documentation with doxygen-awesome-css theme -- GitHub Actions CI workflows (ASan, lint, coverage, docs) -- Chromium C++ code style via `.clang-format` -- `AGENTS.md` for AI coding assistant guidelines - -## When to Use This Skill - -Use this skill when: -- The user wants to create a new C++ library or project from scratch -- The user asks for a "C++ project template" or "C++ repo setup" -- The user needs CMake + Google Test + benchmark scaffolding -- The user wants to follow Pixie-style conventions (header-only, AVX-512 optional, Doxygen docs) - -Do **not** use this skill when: -- Working with an existing codebase (use the `cmake` skill instead) -- The project is not C++ (use a different skill) -- The user only wants a single file or snippet - -## Workflow - -### Step 1: Gather Parameters - -Ask the user for (or infer from context): -- **Project name** (required): Hyphenated lowercase identifier, e.g., `my-lib` -- **Namespace** (optional): C++ namespace. Defaults to project name with hyphens removed, e.g., `mylib` -- **Output directory** (optional): Where to create the project. Defaults to current directory. - -### Step 2: Run the Generator - -Execute the generation script: - -```bash -python3 .kilo/skills/setup-cpp-repo/scripts/init_cpp_project.py \ - --name \ - [--namespace ] \ - [--output-dir ] -``` - -Example: -```bash -python3 .kilo/skills/setup-cpp-repo/scripts/init_cpp_project.py \ - --name succinct-lib --namespace succinct --output-dir . -``` - -### Step 3: Verify the Scaffold - -After generation, the project structure should look like: - -``` -/ -├── CMakeLists.txt -├── CMakePresets.json -├── .clang-format -├── .gitignore -├── README.md -├── AGENTS.md -├── include/ -│ └── / -│ └── .hpp -├── src/ -│ ├── tests/ -│ │ └── unittests.cpp -│ ├── benchmarks/ -│ │ └── benchmarks.cpp -│ └── docs/ -│ ├── Doxyfile.in -│ └── images/ -├── scripts/ -│ └── coverage_report.sh -└── .github/ - └── workflows/ - ├── build-test.yml - ├── linter.yml - ├── coverage.yml - └── doxygen.yml -``` - -### Step 4: Initial Build and Test - -Change into the project directory and run an initial build to verify everything works: - -```bash -cd -cmake --preset release -cmake --build --preset release -j -./build/release/unittests -``` - -If the build and tests pass, the scaffold is ready. - -### Step 5: Hand Off to cmake Skill - -After project creation, use the **`cmake` skill** (`.kilo/skills/cmake/SKILL.md`) for all subsequent build operations. The `cmake` skill documents: -- Build directory conventions with git short-hash suffixes -- How to replicate preset settings with custom build directories -- AddressSanitizer, coverage, and benchmark workflows -- Best practices for out-of-source builds - -## Customization Guide - -### Adding More Test Executables - -Edit `CMakeLists.txt` and add new `add_executable` blocks under the `if(_TESTS)` section, following the pattern of the existing `unittests` target. - -Update `scripts/coverage_report.sh` to run any new test binaries. - -Update `.github/workflows/build-test.yml` to execute new test binaries in CI. - -### Adding More Benchmark Executables - -Edit `CMakeLists.txt` and add new `add_executable` blocks under the `if(_BENCHMARKS)` section, following the pattern of the existing `benchmarks` target. - -### Adding Third-Party Dependencies - -For header-only libraries, prefer `FetchContent` in `CMakeLists.txt`. For compiled libraries, consider vendoring or using a package manager (Conan, vcpkg). - -### Modifying Doxygen Configuration - -Edit `src/docs/Doxyfile.in`. The generated version is intentionally minimal (only non-default settings). Add or override settings as needed. Run `doxygen -g` to see all available options. - -## Reference - -See `references/project_structure.md` for a detailed breakdown of every generated file and its purpose. diff --git a/.kilo/skills/setup-cpp-repo/references/project_structure.md b/.kilo/skills/setup-cpp-repo/references/project_structure.md deleted file mode 100644 index ea80877..0000000 --- a/.kilo/skills/setup-cpp-repo/references/project_structure.md +++ /dev/null @@ -1,116 +0,0 @@ -# Generated Project Structure Reference - -This document describes every file and directory generated by `init_cpp_project.py` and its purpose. - -## Root Files - -### `CMakeLists.txt` -Main CMake configuration. Defines: -- C++20 standard requirements -- `MARCH` cache variable (defaults to `native`) -- `DISABLE_AVX512` option for SIMD fallback -- `ENABLE_ADDRESS_SANITIZER` option for ASan builds -- `_COVERAGE` option for gcov instrumentation -- Build options: `_TESTS`, `_BENCHMARKS`, `_DIAGNOSTICS`, `_DOCS` -- FetchContent dependencies: Google Test, Google Benchmark, spdlog (diagnostics only), Doxygen theme -- Test executable: `unittests` -- Benchmark executable: `benchmarks` -- Custom target: `docs` (when Doxygen is enabled) - -### `CMakePresets.json` -CMake presets (version 4) with a hidden `base` preset. Defines presets for: -- `debug` — Debug build -- `release` — Release build -- `benchmarks` — Release with benchmarks enabled -- `benchmarks-diagnostic` — RelWithDebInfo with diagnostics and libpfm -- `docs` — Documentation build -- `coverage` — Debug with coverage instrumentation -- `asan` — Debug with AddressSanitizer - -### `.clang-format` -Chromium-based C++ formatting configuration. Simplified from the full Chromium style by removing Windows-specific include priorities and IPC macro block definitions. Key settings: -- `BasedOnStyle: Chromium` -- `Standard: Cpp11` -- `InsertBraces: true` -- `InsertNewlineAtEOF: true` -- `IncludeBlocks: Regroup` with generic priority categories - -### `.gitignore` -Standard C++ project ignores: -- `build/`, `.vscode/`, `Testing/` -- `plans/*`, `venv/`, `docs/*` -- `CMakeUserPresets.json` -- `_deps/`, gcov outputs (`*.gcda`, `*.gcno`, `*.gcov`) - -### `README.md` -Minimal project README used as the Doxygen main page. - -### `AGENTS.md` -Project documentation for AI coding assistants. Contains: -- Project overview and architecture conventions -- Technology stack (C++20, CMake, Google Test, Google Benchmark) -- Build commands with all CMake options -- Testing patterns and style guidelines -- Common tasks for AI agents (adding components, modifying SIMD code, adding tests) -- Performance philosophy - -## Directories - -### `include//` -Header-only library API. Contains a placeholder header (`.hpp`) with: -- Doxygen file documentation -- Example function in the project's namespace -- `#pragma once` guard - -### `src/tests/` -Unit test scaffold. Contains `unittests.cpp` with: -- Google Test includes -- Basic assertion test against the placeholder header -- `gtest_main` supplies the test runner entry point - -### `src/benchmarks/` -Benchmark scaffold. Contains `benchmarks.cpp` with: -- Google Benchmark includes -- Example benchmark using `benchmark::DoNotOptimize` -- `BENCHMARK_MAIN()` macro - -### `src/docs/` -Doxygen configuration. Contains: -- `Doxyfile.in` — Trimmed Doxygen config (~300 lines vs. 1100+ in full). Only non-default settings are specified. Key templated values: - - `PROJECT_NAME` - - `INPUT` (points to `include/` and `README.md`) - - `STRIP_FROM_PATH` (strips source dir from file paths) - - `IMAGE_PATH` - - `HTML_EXTRA_STYLESHEET` (doxygen-awesome-css) - - `USE_MDFILE_AS_MAINPAGE` -- `images/` — Empty directory for documentation images - -### `scripts/` -Utility scripts. Contains: -- `coverage_report.sh` — Runs the `coverage` CMake preset, executes tests, and generates gcov reports. Excludes `_deps/`, `third_party/`, and `src/benchmarks/` from coverage. - -### `.github/workflows/` -CI/CD workflows: - -#### `build-test.yml` -Builds the project with AddressSanitizer and runs unit tests on `ubuntu-latest`. Triggered on pushes and PRs to `main`. - -#### `linter.yml` -Runs `clang-format --dry-run --Werror` on all C/C++ files. Triggered on pushes to `main` and all PRs. - -#### `coverage.yml` -Runs the coverage script and uploads results to Codecov. Also uploads coverage artifacts. Triggered on pushes and PRs to `main`. - -#### `doxygen.yml` -Installs Doxygen, builds documentation with the `docs` preset, and deploys HTML output to GitHub Pages. Triggered on pushes to `main` and manual dispatch. - -## Template Substitution - -All generated files use these placeholders, replaced by the script: - -| Placeholder | Example input | Example output | -|-------------|---------------|----------------| -| `{{PROJECT_NAME}}` | `my-lib` | `my-lib` | -| `{{NAMESPACE}}` | `mylib` | `mylib` | -| `{{PROJECT_NAME_UPPER}}` | `MY_LIB` | `MY_LIB` | -| `{{HEADER_NAME}}` | `my_lib.hpp` | `my_lib.hpp` | diff --git a/.kilo/skills/setup-cpp-repo/scripts/init_cpp_project.py b/.kilo/skills/setup-cpp-repo/scripts/init_cpp_project.py deleted file mode 100644 index c859090..0000000 --- a/.kilo/skills/setup-cpp-repo/scripts/init_cpp_project.py +++ /dev/null @@ -1,1051 +0,0 @@ -#!/usr/bin/env python3 -""" -init_cpp_project.py - Scaffold a new C++20 repository following Pixie conventions. - -Usage: - init_cpp_project.py --name [--namespace ] [--output-dir ] - -Example: - init_cpp_project.py --name my-lib --namespace mylib --output-dir . -""" - -import argparse -import os -import sys -from pathlib import Path - - -# --------------------------------------------------------------------------- -# Helper functions -# --------------------------------------------------------------------------- - -def to_upper(name: str) -> str: - """Convert project name to uppercase with underscores.""" - return name.replace("-", "_").upper() - - -def to_snake(name: str) -> str: - """Convert project name to snake_case for filenames.""" - return name.replace("-", "_") - - -# --------------------------------------------------------------------------- -# Templates -# --------------------------------------------------------------------------- - -CMAKE_LISTS_TXT = """cmake_minimum_required(VERSION 3.18) -project({{PROJECT_NAME}}) - -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CXX_EXTENSIONS OFF) - -set(MARCH "native" CACHE STRING "march compiler flag") -add_compile_options("-march=${MARCH}") -message(STATUS "MARCH is '${MARCH}'") - -option(DISABLE_AVX512 "Disable AVX512 instructions" OFF) -if(DISABLE_AVX512) - add_compile_options("-mno-avx512f") - message(STATUS "DISABLE_AVX512 is ON") -endif() - -option(ENABLE_ADDRESS_SANITIZER "Enable AddressSanitizer" OFF) -if(ENABLE_ADDRESS_SANITIZER) - add_compile_options(-fsanitize=address -fno-omit-frame-pointer) - add_link_options(-fsanitize=address) - message(STATUS "AddressSanitizer is ON") -endif() - -option({{PROJECT_NAME_UPPER}}_COVERAGE "Enable coverage instrumentation" OFF) -if({{PROJECT_NAME_UPPER}}_COVERAGE) - add_compile_options(-O0 -g --coverage) - add_link_options(--coverage) - message(STATUS "Coverage instrumentation is ON") -endif() - -# --------------------------------------------------------------------------- -# Build options -# --------------------------------------------------------------------------- -option({{PROJECT_NAME_UPPER}}_TESTS "Build unit tests" ON) -option({{PROJECT_NAME_UPPER}}_BENCHMARKS "Build benchmarks" OFF) -option({{PROJECT_NAME_UPPER}}_DIAGNOSTICS "Include diagnostic logs" OFF) -option({{PROJECT_NAME_UPPER}}_DOCS "Build Doxygen documentation" OFF) - -if({{PROJECT_NAME_UPPER}}_DIAGNOSTICS) - add_compile_definitions({{PROJECT_NAME_UPPER}}_DIAGNOSTICS) - set({{PROJECT_NAME_UPPER}}_DIAGNOSTICS_LIBS spdlog::spdlog_header_only) -endif() - -# --------------------------------------------------------------------------- -# Dependencies (fetched only when needed) -# --------------------------------------------------------------------------- -include(FetchContent) - -if({{PROJECT_NAME_UPPER}}_DIAGNOSTICS) - set(SPDLOG_BUILD_SHARED OFF CACHE BOOL "" FORCE) - set(SPDLOG_BUILD_EXAMPLE OFF CACHE BOOL "" FORCE) - set(SPDLOG_BUILD_TESTING OFF CACHE BOOL "" FORCE) - set(SPDLOG_INSTALL OFF CACHE BOOL "" FORCE) - FetchContent_Declare( - spdlog - GIT_REPOSITORY https://github.com/gabime/spdlog.git - GIT_TAG v1.14.1 - ) - FetchContent_MakeAvailable(spdlog) -endif() - -if({{PROJECT_NAME_UPPER}}_BENCHMARKS) - FetchContent_Declare( - googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG v1.9.4 - ) - set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable Google Benchmark tests") - FetchContent_MakeAvailable(googlebenchmark) -endif() - -if({{PROJECT_NAME_UPPER}}_TESTS) - if(NOT TARGET gtest_main) - FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG v1.17.0 - ) - FetchContent_MakeAvailable(googletest) - endif() - include(GoogleTest) -endif() - -# --------------------------------------------------------------------------- -# Unit tests -# --------------------------------------------------------------------------- -if({{PROJECT_NAME_UPPER}}_TESTS) - enable_testing() - - add_executable(unittests - src/tests/unittests.cpp) - target_include_directories(unittests - PUBLIC include) - target_link_libraries(unittests - gtest_main - ${{{PROJECT_NAME_UPPER}}_DIAGNOSTICS_LIBS}) - gtest_discover_tests(unittests) -endif() - -# --------------------------------------------------------------------------- -# Benchmarks -# --------------------------------------------------------------------------- -if({{PROJECT_NAME_UPPER}}_BENCHMARKS) - add_executable(benchmarks - src/benchmarks/benchmarks.cpp) - target_include_directories(benchmarks - PUBLIC include) - target_link_libraries(benchmarks - benchmark - benchmark_main - ${{{PROJECT_NAME_UPPER}}_DIAGNOSTICS_LIBS}) -endif() - -# --------------------------------------------------------------------------- -# Documentation (Doxygen) -# --------------------------------------------------------------------------- -if({{PROJECT_NAME_UPPER}}_DOCS) - find_package(Doxygen REQUIRED) - - FetchContent_Declare( - doxygen-awesome-css - URL https://github.com/jothepro/doxygen-awesome-css/archive/refs/heads/main.zip - ) - FetchContent_MakeAvailable(doxygen-awesome-css) - - FetchContent_GetProperties(doxygen-awesome-css SOURCE_DIR AWESOME_CSS_DIR) - - set(DOXYFILE_IN ${CMAKE_CURRENT_SOURCE_DIR}/src/docs/Doxyfile.in) - set(DOXYFILE_OUT ${CMAKE_CURRENT_BINARY_DIR}/docs/Doxyfile) - configure_file(${DOXYFILE_IN} ${DOXYFILE_OUT} @ONLY) - - add_custom_target(docs - COMMAND ${DOXYGEN_EXECUTABLE} ${DOXYFILE_OUT} - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - COMMENT "Generating API documentation with Doxygen" - VERBATIM) -endif() -""" - -CMAKE_PRESETS_JSON = """{ - "version": 4, - "cmakeMinimumRequired": { - "major": 3, - "minor": 18, - "patch": 0 - }, - "configurePresets": [ - { - "name": "base", - "hidden": true, - "cacheVariables": { - "CMAKE_EXPORT_COMPILE_COMMANDS": "ON" - } - }, - { - "name": "debug", - "displayName": "Debug", - "inherits": "base", - "binaryDir": "${sourceDir}/build/debug", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug" - } - }, - { - "name": "release", - "displayName": "Release", - "inherits": "base", - "binaryDir": "${sourceDir}/build/release", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release" - } - }, - { - "name": "benchmarks", - "displayName": "Benchmarks", - "inherits": "base", - "binaryDir": "${sourceDir}/build/benchmarks", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "{{PROJECT_NAME_UPPER}}_BENCHMARKS": "ON" - } - }, - { - "name": "benchmarks-diagnostic", - "displayName": "Benchmarks diagnostic build", - "inherits": "base", - "binaryDir": "${sourceDir}/build/release-with-deb", - "cacheVariables": { - "BENCHMARK_ENABLE_LIBPFM": "ON", - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "{{PROJECT_NAME_UPPER}}_DIAGNOSTICS": "ON", - "{{PROJECT_NAME_UPPER}}_BENCHMARKS": "ON" - } - }, - { - "name": "docs", - "displayName": "Docs", - "inherits": "base", - "binaryDir": "${sourceDir}/build/docs", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "{{PROJECT_NAME_UPPER}}_DOCS": "ON" - } - }, - { - "name": "coverage", - "displayName": "Coverage", - "inherits": "base", - "binaryDir": "${sourceDir}/build/coverage", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "{{PROJECT_NAME_UPPER}}_BENCHMARKS": "OFF", - "{{PROJECT_NAME_UPPER}}_COVERAGE": "ON" - } - }, - { - "name": "asan", - "displayName": "AddressSanitizer", - "inherits": "base", - "binaryDir": "${sourceDir}/build/asan", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "{{PROJECT_NAME_UPPER}}_BENCHMARKS": "OFF", - "ENABLE_ADDRESS_SANITIZER": "ON" - } - } - ], - "buildPresets": [ - { - "name": "debug", - "displayName": "Build Debug", - "configurePreset": "debug" - }, - { - "name": "release", - "displayName": "Build Release", - "configurePreset": "release" - }, - { - "name": "benchmarks", - "displayName": "Build Benchmarks", - "configurePreset": "benchmarks" - }, - { - "name": "benchmarks-diagnostic", - "displayName": "Benchmarks diagnostic", - "configurePreset": "benchmarks-diagnostic" - }, - { - "name": "docs", - "displayName": "Build Docs", - "configurePreset": "docs", - "targets": [ - "docs" - ] - }, - { - "name": "coverage", - "displayName": "Build Coverage", - "configurePreset": "coverage" - }, - { - "name": "asan", - "displayName": "Build AddressSanitizer", - "configurePreset": "asan" - } - ] -} -""" - -CLANG_FORMAT = """# Defines the Chromium style for automatic reformatting. -# http://clang.llvm.org/docs/ClangFormatStyleOptions.html -BasedOnStyle: Chromium -# This defaults to 'Auto'. Explicitly set it for a while, so that -# 'vector >' in existing files gets formatted to -# 'vector>'. ('Auto' means that clang-format will only use -# 'int>>' if the file already contains at least one such instance.) -Standard: Cpp11 - -# TODO(crbug.com/1392808): Remove when InsertBraces has been upstreamed into -# the Chromium style (is implied by BasedOnStyle: Chromium). -InsertBraces: true -InsertNewlineAtEOF: true - -# Sort #includes by following -# https://google.github.io/styleguide/cppguide.html#Names_and_Order_of_Includes -IncludeBlocks: Regroup -IncludeCategories: - # C system headers. - - Regex: '^<.*\\.h>' - Priority: 1 - # C++ standard library headers. - - Regex: '^<.*>' - Priority: 2 - # Project headers (quoted includes). - - Regex: '^".*"' - Priority: 3 - # Other libraries. - - Regex: '.*' - Priority: 4 -""" - -GITIGNORE = """build/ -.vscode/ -Testing/ -plans/* -venv/ -docs/* -src/docs/presentations/* -CMakeUserPresets.json -_deps/ -*.gcda -*.gcno -*.gcov -""" - -README_MD = """# {{PROJECT_NAME}} - -{{PROJECT_NAME}} is a C++20 header-only library. - -## Build - -```bash -cmake --preset release -cmake --build --preset release -j -./build/release/unittests -``` -""" - -AGENTS_MD = """# AGENTS.md - AI Coding Assistant Guidelines for {{PROJECT_NAME}} - -## Project Overview - -{{PROJECT_NAME}} is a **C++20 header-only library**. It provides [TODO: brief description]. - -## Skills - -./.kilo/skills/ contains project-specific skills, use them when appropriate. - -## Architecture - -### Project Layout Conventions - -- **`include/`**: Header-only library API (all implementations here, no `.cpp` files) -- **`src/*_tests.cpp`**: Unit tests (Google Test) -- **`src/*_benchmarks.cpp`**: Performance benchmarks (Google Benchmark) -- **`src/docs/`**: Doxygen configuration - -### Key Design Decisions - -1. **Header-only library**: All code in `include/`; no compiled library. -2. **Non-owning spans**: Use `std::span` for external data where appropriate. -3. **SIMD conditional compilation**: Use `#ifdef {{PROJECT_NAME_UPPER}}_AVX512_SUPPORT` / `{{PROJECT_NAME_UPPER}}_AVX2_SUPPORT` with scalar fallbacks. -4. **Target domain**: Optimized for practical data sizes. -5. **Platform**: Linux/Unix is the primary target platform. - -### Why Header-Only? - -- **SIMD flexibility**: Users compile with their target `-march` flags. -- **Better inlining**: Compiler sees full implementation. -- **No ABI issues**: Works across compilers and standard library versions. -- **Easy integration**: Users just `#include` headers. -- **Template-friendly**: No explicit instantiation needed. - -## Technology Stack - -- **Language**: C++20 (required features: `std::span`, `std::popcount`, ``) -- **Build**: CMake >= 3.18 -- **Testing**: Google Test v1.17.0 -- **Benchmarking**: Google Benchmark v1.9.4 -- **SIMD**: AVX-512 (primary), AVX2 (fallback), scalar fallbacks -- **Style**: Chromium C++ style (`.clang-format`) - -### Dependencies - -The library itself is header-only and has **no runtime dependencies**. Build-time dependencies are managed via CMake FetchContent and controlled by options: - -| Option | Default | What it enables | -|--------|---------|-----------------| -| `{{PROJECT_NAME_UPPER}}_TESTS` | `ON` | Unit tests (fetches Google Test) | -| `{{PROJECT_NAME_UPPER}}_BENCHMARKS` | `OFF` | Benchmarks (fetches Google Benchmark) | - -## Build Commands - -```bash -# Standard build (Release) -cmake -B build/release -DCMAKE_BUILD_TYPE=Release -cmake --build build/release -j - -# Debug build -cmake -B build/debug -DCMAKE_BUILD_TYPE=Debug -cmake --build build/debug -j - -# Without AVX-512 (AVX2 fallback) -cmake -B build/release -DDISABLE_AVX512=ON -cmake --build build/release -j - -# With AddressSanitizer -cmake -B build/asan -DENABLE_ADDRESS_SANITIZER=ON -cmake --build build/asan -j - -# Custom march flag -cmake -B build/release -DMARCH=icelake-client -cmake --build build/release -j - -# Tests only (no benchmarks) -cmake -B build/release -D{{PROJECT_NAME_UPPER}}_BENCHMARKS=OFF -cmake --build build/release -j -``` - -## Testing - -### Running Tests - -```bash -./build/release/unittests -``` - -### Testing Patterns - -- **Differential testing**: Compare against naive reference implementations. -- **Randomized testing**: Random inputs with configurable seed. -- **Exhaustive short inputs**: Test all patterns for small sizes. - -## Code Style Guidelines - -1. **Formatting**: Run `clang-format` before committing (Chromium style) -2. **Namespace**: All library code in `{{NAMESPACE}}` namespace -3. **Documentation**: Use Doxygen-style comments for public API -4. **Constants**: Use `constexpr` for compile-time values -5. **Alignment**: Be aware of data alignment; prefer 64-byte aligned array allocations where performance matters - -## CI/CD Workflows - -- **build-test.yml**: Builds and runs tests with AddressSanitizer -- **linter.yml**: Clang-format checks on all C/C++ files -- **coverage.yml**: Coverage reporting with codecov upload -- **doxygen.yml**: Documentation generation and GitHub Pages deployment - -## Common Tasks for AI Agents - -### Adding a New Component - -1. Create header in `include/{{NAMESPACE}}/` with Doxygen documentation -2. Add unit tests in `src/tests/_tests.cpp` -3. Add benchmarks in `src/benchmarks/_benchmarks.cpp` -4. Update `CMakeLists.txt` with new executables -5. Run `clang-format` on new files - -### Modifying SIMD Code - -1. Provide implementations for: - - AVX-512 (`#ifdef {{PROJECT_NAME_UPPER}}_AVX512_SUPPORT`) - - AVX2 (`#ifdef {{PROJECT_NAME_UPPER}}_AVX2_SUPPORT`) - - Scalar fallback -2. Test with `-DDISABLE_AVX512=ON` to verify fallback works -3. Benchmark to ensure performance is maintained - -### Adding Tests - -1. Use Google Test framework -2. Include naive reference implementation for differential testing -3. Add edge cases: empty input, single element, boundary conditions -4. Use random testing with configurable seed for reproducibility - -## Performance Philosophy - -- **Goal**: Best practical performance (not just asymptotic complexity) -- **Approach**: Benchmark-driven optimization using Google Benchmark -- **SIMD**: Leverage vectorized operations where beneficial -- **Cache efficiency**: Align data structures to cache line boundaries (64 bytes) -""" - -DOXYFILE_IN = """# Doxyfile - -DOXYFILE_ENCODING = UTF-8 -PROJECT_NAME = "{{PROJECT_NAME}}" -PROJECT_NUMBER = -PROJECT_BRIEF = -PROJECT_LOGO = -PROJECT_ICON = -OUTPUT_DIRECTORY = docs -CREATE_SUBDIRS = NO -CREATE_SUBDIRS_LEVEL = 8 -ALLOW_UNICODE_NAMES = NO -OUTPUT_LANGUAGE = English -BRIEF_MEMBER_DESC = YES -REPEAT_BRIEF = YES -ABBREVIATE_BRIEF = "The $name class" \ - "The $name widget" \ - "The $name file" \ - is \ - provides \ - specifies \ - contains \ - represents \ - a \ - an \ - the -ALWAYS_DETAILED_SEC = NO -INLINE_INHERITED_MEMB = NO -FULL_PATH_NAMES = YES -STRIP_FROM_PATH = @CMAKE_CURRENT_SOURCE_DIR@ -STRIP_FROM_INC_PATH = -SHORT_NAMES = NO -JAVADOC_AUTOBRIEF = NO -JAVADOC_BANNER = NO -QT_AUTOBRIEF = NO -MULTILINE_CPP_IS_BRIEF = NO -PYTHON_DOCSTRING = YES -INHERIT_DOCS = YES -SEPARATE_MEMBER_PAGES = NO -TAB_SIZE = 4 -ALIASES = -OPTIMIZE_OUTPUT_FOR_C = NO -OPTIMIZE_OUTPUT_JAVA = NO -OPTIMIZE_FOR_FORTRAN = NO -OPTIMIZE_OUTPUT_VHDL = NO -OPTIMIZE_OUTPUT_SLICE = NO -EXTENSION_MAPPING = -MARKDOWN_SUPPORT = YES -MARKDOWN_STRICT = YES -TOC_INCLUDE_HEADINGS = 6 -MARKDOWN_ID_STYLE = DOXYGEN -AUTOLINK_SUPPORT = YES -AUTOLINK_IGNORE_WORDS = -BUILTIN_STL_SUPPORT = NO -CPP_CLI_SUPPORT = NO -SIP_SUPPORT = NO -IDL_PROPERTY_SUPPORT = YES -DISTRIBUTE_GROUP_DOC = NO -GROUP_NESTED_COMPOUNDS = NO -SUBGROUPING = YES -INLINE_GROUPED_CLASSES = NO -INLINE_SIMPLE_STRUCTS = NO -TYPEDEF_HIDES_STRUCT = NO -LOOKUP_CACHE_SIZE = 0 -NUM_PROC_THREADS = 1 -TIMESTAMP = NO -EXTRACT_ALL = NO -EXTRACT_PRIVATE = NO -EXTRACT_PRIV_VIRTUAL = NO -EXTRACT_PACKAGE = NO -EXTRACT_STATIC = NO -EXTRACT_LOCAL_CLASSES = YES -EXTRACT_LOCAL_METHODS = NO -EXTRACT_ANON_NSPACES = NO -RESOLVE_UNNAMED_PARAMS = YES -HIDE_UNDOC_MEMBERS = NO -HIDE_UNDOC_CLASSES = NO -HIDE_UNDOC_NAMESPACES = YES -HIDE_FRIEND_COMPOUNDS = NO -HIDE_IN_BODY_DOCS = NO -INTERNAL_DOCS = NO -CASE_SENSE_NAMES = SYSTEM -HIDE_SCOPE_NAMES = NO -HIDE_COMPOUND_REFERENCE= NO -SHOW_HEADERFILE = YES -SHOW_INCLUDE_FILES = YES -SHOW_GROUPED_MEMB_INC = NO -FORCE_LOCAL_INCLUDES = NO -INLINE_INFO = YES -SORT_MEMBER_DOCS = YES -SORT_BRIEF_DOCS = NO -SORT_MEMBERS_CTORS_1ST = NO -SORT_GROUP_NAMES = NO -SORT_BY_SCOPE_NAME = NO -STRICT_PROTO_MATCHING = NO -GENERATE_TODOLIST = YES -GENERATE_TESTLIST = YES -GENERATE_BUGLIST = YES -GENERATE_DEPRECATEDLIST= YES -ENABLED_SECTIONS = -MAX_INITIALIZER_LINES = 30 -SHOW_USED_FILES = YES -SHOW_FILES = YES -SHOW_NAMESPACES = YES -FILE_VERSION_FILTER = -LAYOUT_FILE = -CITE_BIB_FILES = -EXTERNAL_TOOL_PATH = -QUIET = NO -WARNINGS = YES -WARN_IF_UNDOCUMENTED = YES -WARN_IF_DOC_ERROR = YES -WARN_IF_INCOMPLETE_DOC = YES -WARN_NO_PARAMDOC = NO -WARN_IF_UNDOC_ENUM_VAL = NO -WARN_LAYOUT_FILE = YES -WARN_AS_ERROR = NO -WARN_FORMAT = "$file:$line: $text" -WARN_LINE_FORMAT = "at line $line of file $file" -WARN_LOGFILE = -INPUT = @CMAKE_CURRENT_SOURCE_DIR@/include \ - @CMAKE_CURRENT_SOURCE_DIR@/README.md -INPUT_ENCODING = UTF-8 -INPUT_FILE_ENCODING = -FILE_PATTERNS = *.c \ - *.cc \ - *.cxx \ - *.cpp \ - *.h \ - *.hh \ - *.hxx \ - *.hpp -RECURSIVE = YES -EXCLUDE = -EXCLUDE_SYMLINKS = NO -EXCLUDE_PATTERNS = -EXCLUDE_SYMBOLS = -EXAMPLE_PATH = -EXAMPLE_PATTERNS = * -EXAMPLE_RECURSIVE = NO -IMAGE_PATH = @CMAKE_CURRENT_SOURCE_DIR@/src/docs/images -INPUT_FILTER = -FILTER_PATTERNS = -FILTER_SOURCE_FILES = NO -FILTER_SOURCE_PATTERNS = -USE_MDFILE_AS_MAINPAGE = @CMAKE_CURRENT_SOURCE_DIR@/README.md -IMPLICIT_DIR_DOCS = YES -FORTRAN_COMMENT_AFTER = 72 -SOURCE_BROWSER = NO -INLINE_SOURCES = NO -STRIP_CODE_COMMENTS = YES -REFERENCED_BY_RELATION = NO -REFERENCES_RELATION = NO -REFERENCES_LINK_SOURCE = YES -SOURCE_TOOLTIPS = YES -USE_HTAGS = NO -VERBATIM_HEADERS = YES -CLANG_ASSISTED_PARSING = NO -CLANG_ADD_INC_PATHS = YES -CLANG_OPTIONS = -CLANG_DATABASE_PATH = -ALPHABETICAL_INDEX = YES -IGNORE_PREFIX = -GENERATE_HTML = YES -HTML_OUTPUT = html -HTML_FILE_EXTENSION = .html -HTML_HEADER = -HTML_FOOTER = -HTML_STYLESHEET = -HTML_EXTRA_STYLESHEET = @AWESOME_CSS_DIR@/doxygen-awesome.css -HTML_EXTRA_FILES = -HTML_COLORSTYLE = AUTO_LIGHT -HTML_COLORSTYLE_HUE = 220 -HTML_COLORSTYLE_SAT = 100 -HTML_COLORSTYLE_GAMMA = 80 -HTML_DYNAMIC_MENUS = YES -HTML_DYNAMIC_SECTIONS = NO -HTML_CODE_FOLDING = YES -HTML_COPY_CLIPBOARD = YES -HTML_PROJECT_COOKIE = -HTML_INDEX_NUM_ENTRIES = 100 -GENERATE_DOCSET = NO -DOCSET_FEEDNAME = "Doxygen generated docs" -DOCSET_FEEDURL = -DOCSET_BUNDLE_ID = org.doxygen.Project -DOCSET_PUBLISHER_ID = org.doxygen.Publisher -DOCSET_PUBLISHER_NAME = Publisher -GENERATE_HTMLHELP = NO -CHM_FILE = -HHC_LOCATION = -GENERATE_CHI = NO -CHM_INDEX_ENCODING = -BINARY_TOC = NO -TOC_EXPAND = NO -SITEMAP_URL = -GENERATE_QHP = NO -QCH_FILE = -QHP_NAMESPACE = org.doxygen.Project -QHP_VIRTUAL_FOLDER = doc -QHP_CUST_FILTER_NAME = -QHP_CUST_FILTER_ATTRS = -QHP_SECT_FILTER_ATTRS = -QHG_LOCATION = -GENERATE_ECLIPSEHELP = NO -ECLIPSE_DOC_ID = org.doxygen.Project -DISABLE_INDEX = NO -GENERATE_TREEVIEW = YES -PAGE_OUTLINE_PANEL = YES -FULL_SIDEBAR = NO -ENUM_VALUES_PER_LINE = 4 -SHOW_ENUM_VALUES = NO -TREEVIEW_WIDTH = 250 -EXT_LINKS_IN_WINDOW = NO -OBFUSCATE_EMAILS = YES -HTML_FORMULA_FORMAT = png -FORMULA_FONTSIZE = 10 -FORMULA_MACROFILE = -USE_MATHJAX = NO -MATHJAX_VERSION = MathJax_2 -MATHJAX_FORMAT = HTML-CSS -MATHJAX_RELPATH = -MATHJAX_EXTENSIONS = -MATHJAX_CODEFILE = -SEARCHENGINE = YES -SERVER_BASED_SEARCH = NO -EXTERNAL_SEARCH = NO -SEARCHENGINE_URL = -SEARCHDATA_FILE = searchdata.xml -EXTERNAL_SEARCH_ID = -EXTRA_SEARCH_MAPPINGS = -GENERATE_LATEX = NO -""" - -COVERAGE_REPORT_SH = """#!/usr/bin/env bash -set -euo pipefail - -ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -BUILD_DIR="${ROOT_DIR}/build/coverage" - -cmake --preset coverage -cmake --build --preset coverage - -"${BUILD_DIR}/unittests" - -cd "${BUILD_DIR}" -find . -name "*.gcda" > gcov_files.txt -while read -r f; do - case "${f}" in - *"/_deps/"*|*"/third_party/"*|*"/src/benchmarks/"*) - ;; - *) - gcov -pb "${f}" >> coverage.txt - ;; - esac -done < gcov_files.txt -echo "gcov report written to ${BUILD_DIR}/coverage.txt" -""" - -BUILD_TEST_YML = """name: Tests (ASan) - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - build-and-test: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Create Build Directory - run: mkdir build - - - name: Configure CMake - working-directory: ./build - run: cmake -DDISABLE_AVX512=ON -DENABLE_ADDRESS_SANITIZER=ON -D{{PROJECT_NAME_UPPER}}_BENCHMARKS=OFF .. - - - name: Build Project - working-directory: ./build - run: make -j - - - name: Run Unittests - working-directory: ./build - run: ./unittests -""" - -LINTER_YML = """name: Clang Format Lint - -on: - pull_request: - push: - branches: [main] - -jobs: - clang-format: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Install clang-format - run: sudo apt-get update && sudo apt-get install -y clang-format - - - name: Run clang-format check - run: | - mapfile -t FILES < <(find include src -type f \\( -name '*.cpp' -o -name '*.hpp' -o -name '*.cc' -o -name '*.c' -o -name '*.h' \\)) - clang-format --version - if [ ${#FILES[@]} -eq 0 ]; then - echo "No C/C++ files found." - exit 0 - fi - - clang-format --dry-run --Werror "${FILES[@]}" -""" - -COVERAGE_YML = """name: coverage - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - coverage: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Create Build Directory - run: mkdir build - - - name: Run coverage - run: ./scripts/coverage_report.sh - - - name: Upload to Codecov - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} - files: build/coverage/coverage.txt - flags: gcov - fail_ci_if_error: false - - - name: Upload coverage artifacts - uses: actions/upload-artifact@v4 - with: - name: coverage-gcov - path: | - build/coverage/coverage.txt - build/coverage/*.gcov -""" - -DOXYGEN_YML = """# Simple workflow for deploying static content to GitHub Pages -name: Deploy static content to Pages - -on: - # Runs on pushes targeting the default branch - push: - branches: ["main"] - - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: - -# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages -permissions: - contents: read - pages: write - id-token: write - -# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. -# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. -concurrency: - group: "pages" - cancel-in-progress: false - -jobs: - # Single deploy job since we're just deploying - deploy: - environment: - name: github-pages - url: ${{ steps.deployment.outputs.page_url }} - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Install Doxygen v1.13.2 - run: | - transformed_version=$(echo "1.13.2" | tr '.' '_') - wget https://github.com/doxygen/doxygen/releases/download/Release_${transformed_version}/doxygen-1.13.2.linux.bin.tar.gz - tar -xzf doxygen-1.13.2.linux.bin.tar.gz - sudo mv doxygen-1.13.2/bin/doxygen /usr/local/bin/doxygen - shell: bash - - name: Cmake configure - run: cmake -S ${{github.workspace}} -B ${{github.workspace}}/build -D{{PROJECT_NAME_UPPER}}_DOCS=ON -D{{PROJECT_NAME_UPPER}}_TESTS=OFF -D{{PROJECT_NAME_UPPER}}_BENCHMARKS=OFF - - name: Build docs - run: cmake --build ${{github.workspace}}/build --target docs - - name: Upload artifact - uses: actions/upload-pages-artifact@v3 - with: - # Upload entire repository - path: ${{github.workspace}}/build/docs/html - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v4 -""" - -HEADER_HPP = """#pragma once - -/** - * @file {{HEADER_NAME}} - * @brief Main header for the {{PROJECT_NAME}} library - */ - -namespace {{NAMESPACE}} { - -/** - * @brief Example function. - * - * TODO: Replace with actual library functionality. - */ -inline int example() { - return 42; -} - -} // namespace {{NAMESPACE}} -""" - -UNITTESTS_CPP = """#include - -#include "{{NAMESPACE}}/{{HEADER_NAME}}" - -TEST(ExampleTest, BasicAssertion) { - EXPECT_EQ({{NAMESPACE}}::example(), 42); -} -""" - -BENCHMARKS_CPP = """#include - -#include "{{NAMESPACE}}/{{HEADER_NAME}}" - -static void BM_Example(benchmark::State& state) { - for (auto _ : state) { - benchmark::DoNotOptimize({{NAMESPACE}}::example()); - } -} - -BENCHMARK(BM_Example); - -BENCHMARK_MAIN(); -""" - - -# --------------------------------------------------------------------------- -# Generation logic -# --------------------------------------------------------------------------- - -def generate(args: argparse.Namespace) -> None: - project_name = args.name - namespace = args.namespace or project_name.replace("-", "") - project_name_upper = to_upper(project_name) - header_name = f"{to_snake(project_name)}.hpp" - output_dir = Path(args.output_dir).resolve() / project_name - - if output_dir.exists(): - print(f"Error: output directory already exists: {output_dir}") - sys.exit(1) - - substitutions = { - "{{PROJECT_NAME}}": project_name, - "{{NAMESPACE}}": namespace, - "{{PROJECT_NAME_UPPER}}": project_name_upper, - "{{HEADER_NAME}}": header_name, - } - - def sub(text: str) -> str: - for key, value in substitutions.items(): - text = text.replace(key, value) - return text - - # Create directories - (output_dir / "include" / namespace).mkdir(parents=True) - (output_dir / "src" / "tests").mkdir(parents=True) - (output_dir / "src" / "benchmarks").mkdir(parents=True) - (output_dir / "src" / "docs").mkdir(parents=True) - (output_dir / "src" / "docs" / "images").mkdir(parents=True) - (output_dir / "scripts").mkdir(parents=True) - (output_dir / ".github" / "workflows").mkdir(parents=True) - - # Write files - files = { - output_dir / "CMakeLists.txt": sub(CMAKE_LISTS_TXT), - output_dir / "CMakePresets.json": sub(CMAKE_PRESETS_JSON), - output_dir / ".clang-format": sub(CLANG_FORMAT), - output_dir / ".gitignore": sub(GITIGNORE), - output_dir / "README.md": sub(README_MD), - output_dir / "AGENTS.md": sub(AGENTS_MD), - output_dir / "src" / "docs" / "Doxyfile.in": sub(DOXYFILE_IN), - output_dir / "scripts" / "coverage_report.sh": sub(COVERAGE_REPORT_SH), - output_dir / ".github" / "workflows" / "build-test.yml": sub(BUILD_TEST_YML), - output_dir / ".github" / "workflows" / "linter.yml": sub(LINTER_YML), - output_dir / ".github" / "workflows" / "coverage.yml": sub(COVERAGE_YML), - output_dir / ".github" / "workflows" / "doxygen.yml": sub(DOXYGEN_YML), - output_dir / "include" / namespace / header_name: sub(HEADER_HPP), - output_dir / "src" / "tests" / "unittests.cpp": sub(UNITTESTS_CPP), - output_dir / "src" / "benchmarks" / "benchmarks.cpp": sub(BENCHMARKS_CPP), - } - - for path, content in files.items(): - path.write_text(content) - print(f"Created: {path.relative_to(output_dir.parent)}") - - # Make coverage script executable - (output_dir / "scripts" / "coverage_report.sh").chmod(0o755) - - print(f"\\nProject '{project_name}' generated successfully at {output_dir}") - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Scaffold a new C++20 repository following Pixie conventions." - ) - parser.add_argument("--name", required=True, help="Project name (e.g., my-lib)") - parser.add_argument( - "--namespace", - help="C++ namespace (defaults to project name with hyphens removed)", - ) - parser.add_argument( - "--output-dir", - default=".", - help="Output directory (default: current directory)", - ) - args = parser.parse_args() - generate(args) - - -if __name__ == "__main__": - main() diff --git a/CMakeLists.txt b/CMakeLists.txt index e5eeaec..e648102 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -201,6 +201,14 @@ if(PIXIE_BENCHMARKS) benchmark ${PIXIE_DIAGNOSTICS_LIBS}) + add_executable(bench_rmm_btree + src/benchmarks/bench_rmm_btree.cpp) + target_include_directories(bench_rmm_btree + PUBLIC include) + target_link_libraries(bench_rmm_btree + benchmark + ${PIXIE_DIAGNOSTICS_LIBS}) + if(PIXIE_THIRD_PARTY_BACKENDS) add_executable(bench_rmm_sdsl src/benchmarks/bench_rmm_sdsl.cpp) diff --git a/include/pixie/bits.h b/include/pixie/bits.h index 50a8cc5..7d26597 100644 --- a/include/pixie/bits.h +++ b/include/pixie/bits.h @@ -45,6 +45,140 @@ static inline const __m256i mask_first_half = _mm256_setr_epi8( // clang-format on #endif +/** + * @brief Test 16 int16 RmM btree child ranges for a node-local target. + * @details Each lane represents one child summary. The function checks whether + * @p target lies in `[prefix_before[i] + min_excess[i], + * prefix_before[i] + max_excess[i]]`. When @p include_zero_boundary is true, + * it also accepts `target == prefix_before[i]`, which represents the left + * boundary match used by backward search. + * @param prefix_before Exclusive prefix excess before each child. + * @param min_excess Per-child minimum excess relative to the child start. + * @param max_excess Per-child maximum excess relative to the child start. + * @param target Target excess relative to the start of the parent node. + * @param include_zero_boundary Whether to accept child left-boundary matches. + * @return Bit mask with bit `i` set when lane `i` can contain the target. + */ +static inline uint32_t rmm_btree_match_mask_i16x16(const int16_t* prefix_before, + const int16_t* min_excess, + const int16_t* max_excess, + int16_t target, + bool include_zero_boundary) { +#ifdef PIXIE_AVX2_SUPPORT + const __m256i vtarget = _mm256_set1_epi16(target); + const __m256i vprefix = + _mm256_loadu_si256(reinterpret_cast(prefix_before)); + const __m256i vmin = + _mm256_loadu_si256(reinterpret_cast(min_excess)); + const __m256i vmax = + _mm256_loadu_si256(reinterpret_cast(max_excess)); + + const __m256i lower = _mm256_adds_epi16(vprefix, vmin); + const __m256i upper = _mm256_adds_epi16(vprefix, vmax); + const __m256i ge_lower = _mm256_or_si256(_mm256_cmpgt_epi16(vtarget, lower), + _mm256_cmpeq_epi16(vtarget, lower)); + const __m256i le_upper = _mm256_or_si256(_mm256_cmpgt_epi16(upper, vtarget), + _mm256_cmpeq_epi16(upper, vtarget)); + __m256i matched = _mm256_and_si256(ge_lower, le_upper); + if (include_zero_boundary) { + matched = _mm256_or_si256(matched, _mm256_cmpeq_epi16(vtarget, vprefix)); + } + + const uint32_t byte_mask = + static_cast(_mm256_movemask_epi8(matched)); + uint32_t result = 0; + for (size_t lane = 0; lane < 16; ++lane) { + const uint32_t lane_mask = 0x3u << (lane * 2); + if ((byte_mask & lane_mask) == lane_mask) { + result |= uint32_t{1} << lane; + } + } + return result; +#else + uint32_t result = 0; + for (size_t lane = 0; lane < 16; ++lane) { + const int lower = prefix_before[lane] + min_excess[lane]; + const int upper = prefix_before[lane] + max_excess[lane]; + const bool found = (lower <= target && target <= upper) || + (include_zero_boundary && target == prefix_before[lane]); + if (found) { + result |= uint32_t{1} << lane; + } + } + return result; +#endif +} + +/** + * @brief Test 4 int64 RmM btree child ranges for a node-local target. + * @details Each lane represents one child summary. The function subtracts the + * child prefix from @p target to form a child-relative target, then checks it + * against the child's `[min_excess, max_excess]` range. When + * @p include_zero_boundary is true, it also accepts a zero relative target for + * the left-boundary match used by backward search. + * @param prefix_before Exclusive prefix excess before each child. + * @param min_excess Per-child minimum excess relative to the child start. + * @param max_excess Per-child maximum excess relative to the child start. + * @param target Target excess relative to the start of the parent node. + * @param include_zero_boundary Whether to accept child left-boundary matches. + * @return Bit mask with bit `i` set when lane `i` can contain the target. + */ +static inline uint32_t rmm_btree_match_mask_i64x4(const int64_t* prefix_before, + const int64_t* min_excess, + const int64_t* max_excess, + int64_t target, + bool include_zero_boundary) { +#ifdef PIXIE_AVX2_SUPPORT + const __m256i vtarget = _mm256_set1_epi64x(target); + const __m256i vprefix = + _mm256_loadu_si256(reinterpret_cast(prefix_before)); + const __m256i vmin = + _mm256_loadu_si256(reinterpret_cast(min_excess)); + const __m256i vmax = + _mm256_loadu_si256(reinterpret_cast(max_excess)); + + const __m256i relative = _mm256_sub_epi64(vtarget, vprefix); + const __m256i ge_min = _mm256_or_si256(_mm256_cmpgt_epi64(relative, vmin), + _mm256_cmpeq_epi64(relative, vmin)); + const __m256i le_max = _mm256_or_si256(_mm256_cmpgt_epi64(vmax, relative), + _mm256_cmpeq_epi64(vmax, relative)); + __m256i matched = _mm256_and_si256(ge_min, le_max); + if (include_zero_boundary) { + matched = _mm256_or_si256(matched, _mm256_cmpeq_epi64(vtarget, vprefix)); + } + + const uint32_t byte_mask = + static_cast(_mm256_movemask_epi8(matched)); + uint32_t result = 0; + for (size_t lane = 0; lane < 4; ++lane) { + const uint32_t lane_mask = 0xffu << (lane * 8); + if ((byte_mask & lane_mask) == lane_mask) { + result |= uint32_t{1} << lane; + } + } + return result; +#else + uint32_t result = 0; + for (size_t lane = 0; lane < 4; ++lane) { + const int64_t relative = target - prefix_before[lane]; + const bool found = + (min_excess[lane] <= relative && relative <= max_excess[lane]) || + (include_zero_boundary && relative == 0); + if (found) { + result |= uint32_t{1} << lane; + } + } + return result; +#endif +} + +/** + * @brief Return a mask with the lowest @p num bits set. + * @details Values greater than or equal to 64 produce an all-ones mask, which + * avoids undefined behavior from shifting by the word width. + * @param num Number of low bits to set. + * @return A 64-bit mask containing ones in positions `[0, num)`. + */ static inline uint64_t first_bits_mask(size_t num) { return num >= 64 ? UINT64_MAX : ((1llu << num) - 1); } @@ -663,107 +797,106 @@ static inline const __m256i excess_lut_pos2 = _mm256_setr_epi8( -1, 1, 1, 3, -3, -1, -1, 1, -1, 1, 1, 3); +static inline const __m256i excess_lut_pack_multiplier = + _mm256_set1_epi16(0x1001); +static inline const __m256i excess_lut_bit0 = _mm256_set1_epi8(1); +static inline const __m256i excess_lut_bit1 = _mm256_set1_epi8(2); +static inline const __m256i excess_lut_bit2 = _mm256_set1_epi8(4); +static inline const __m256i excess_lut_bit3 = _mm256_set1_epi8(8); +static inline const __m128i excess_lut_nibble_mask = _mm_set1_epi8(0x0F); // clang-format on #endif /** - * @brief Find every prefix whose excess equals target_x in a 512-bit bitstring. + * @brief Find every prefix whose excess equals target_x in a 128-bit bitstring. * - * Excess(i) = 2*popcount(bits[0..i-1]) - i for i in [0..512]. + * Excess(i) = 2*popcount(bits[0..i-1]) - i for i in [0..128]. * Bit (w*64 + b) of out[w] is set iff excess(w*64 + b + 1) == target_x. * I.e. out bit index b corresponds to prefix length (b+1). * - * @param s 8 little-endian uint64_t words (bit 0 of s[0] is the first bit). - * @param target_x Target excess value in [-512, 512]; outside this range out is + * @param s 2 little-endian uint64_t words (bit 0 of s[0] is the first bit). + * @param target_x Target excess value in [-128, 128]; outside this range out is * zeroed. - * @param out 8 uint64_t words receiving the result bitmask. + * @param out 2 uint64_t words receiving the result bitmask. + * @return Total excess change across the 128-bit bitstring. */ -static inline void excess_positions_512(const uint64_t* s, - int target_x, - uint64_t* out) noexcept { - out[0] = out[1] = out[2] = out[3] = 0; - out[4] = out[5] = out[6] = out[7] = 0; - - if (target_x < -512 || target_x > 512) { - return; +static inline int excess_positions_128(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = 0; + const int block_delta = 2 * (std::popcount(s[0]) + std::popcount(s[1])) - 128; + + if (target_x < -128 || target_x > 128) { + return block_delta; } #ifdef PIXIE_AVX2_SUPPORT - int cur = 0; const __m256i vdelta = excess_lut_delta; const __m256i vpos0 = excess_lut_pos0; const __m256i vpos1 = excess_lut_pos1; const __m256i vpos2 = excess_lut_pos2; - const __m256i vmult = _mm256_set1_epi16(0x1001); - const __m256i vbit0 = _mm256_set1_epi8(1); - const __m256i vbit1 = _mm256_set1_epi8(2); - const __m256i vbit2 = _mm256_set1_epi8(4); - const __m256i vbit3 = _mm256_set1_epi8(8); - const __m128i vnibble_mask = _mm_set1_epi8(0x0F); - - for (int k = 0; k < 4; ++k) { - int block_delta = - 2 * (std::popcount(s[2 * k]) + std::popcount(s[2 * k + 1])) - 128; - - const int d = 2 * target_x - block_delta; - if (d < -128 || d > 128) { - target_x -= block_delta; - continue; - } - __m128i word_vec = _mm_loadu_si128((const __m128i*)&s[2 * k]); - __m128i lo_nibbles = _mm_and_si128(word_vec, vnibble_mask); - __m128i hi_nibbles = - _mm_and_si128(_mm_srli_epi16(word_vec, 4), vnibble_mask); + const __m256i vmult = excess_lut_pack_multiplier; + const __m256i vbit0 = excess_lut_bit0; + const __m256i vbit1 = excess_lut_bit1; + const __m256i vbit2 = excess_lut_bit2; + const __m256i vbit3 = excess_lut_bit3; + const __m128i vnibble_mask = excess_lut_nibble_mask; + + const int d = 2 * target_x - block_delta; + if (d < -128 || d > 128) { + return block_delta; + } - __m128i unpack_lo = _mm_unpacklo_epi8(lo_nibbles, hi_nibbles); - __m128i unpack_hi = _mm_unpackhi_epi8(lo_nibbles, hi_nibbles); + __m128i word_vec = _mm_loadu_si128((const __m128i*)s); + __m128i lo_nibbles = _mm_and_si128(word_vec, vnibble_mask); + __m128i hi_nibbles = _mm_and_si128(_mm_srli_epi16(word_vec, 4), vnibble_mask); - __m256i nibbles = _mm256_inserti128_si256(_mm256_castsi128_si256(unpack_lo), - unpack_hi, 1); + __m128i unpack_lo = _mm_unpacklo_epi8(lo_nibbles, hi_nibbles); + __m128i unpack_hi = _mm_unpackhi_epi8(lo_nibbles, hi_nibbles); - __m256i ps = _mm256_shuffle_epi8(vdelta, nibbles); - ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 1)); - ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 2)); - ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 4)); - ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 8)); + __m256i nibbles = + _mm256_inserti128_si256(_mm256_castsi128_si256(unpack_lo), unpack_hi, 1); - __m128i ps_lo = _mm256_castsi256_si128(ps); - __m128i ps_hi = _mm256_extracti128_si256(ps, 1); - __m128i carry = _mm_set1_epi8((int8_t)_mm_extract_epi8(ps_lo, 15)); - ps_hi = _mm_add_epi8(ps_hi, carry); - ps = _mm256_inserti128_si256(_mm256_castsi128_si256(ps_lo), ps_hi, 1); + __m256i ps = _mm256_shuffle_epi8(vdelta, nibbles); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 1)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 2)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 4)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 8)); - __m256i b = _mm256_permute2x128_si256(ps, ps, 0x08); - __m256i excl_ps = _mm256_alignr_epi8(ps, b, 15); + __m128i ps_lo = _mm256_castsi256_si128(ps); + __m128i ps_hi = _mm256_extracti128_si256(ps, 1); + __m128i carry = _mm_set1_epi8((int8_t)_mm_extract_epi8(ps_lo, 15)); + ps_hi = _mm_add_epi8(ps_hi, carry); + ps = _mm256_inserti128_si256(_mm256_castsi128_si256(ps_lo), ps_hi, 1); - __m256i vtgt = _mm256_set1_epi8((int8_t)target_x); - __m256i t = _mm256_sub_epi8(vtgt, excl_ps); + __m256i b = _mm256_permute2x128_si256(ps, ps, 0x08); + __m256i excl_ps = _mm256_alignr_epi8(ps, b, 15); - __m256i cmp0 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos0, nibbles), t); - __m256i cmp1 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos1, nibbles), t); - __m256i cmp2 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos2, nibbles), t); - __m256i cmp3 = _mm256_cmpeq_epi8(ps, vtgt); + __m256i vtgt = _mm256_set1_epi8((int8_t)target_x); + __m256i t = _mm256_sub_epi8(vtgt, excl_ps); - __m256i bit0 = _mm256_and_si256(cmp0, vbit0); - __m256i bit1 = _mm256_and_si256(cmp1, vbit1); - __m256i bit2 = _mm256_and_si256(cmp2, vbit2); - __m256i bit3 = _mm256_and_si256(cmp3, vbit3); + __m256i cmp0 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos0, nibbles), t); + __m256i cmp1 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos1, nibbles), t); + __m256i cmp2 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos2, nibbles), t); + __m256i cmp3 = _mm256_cmpeq_epi8(ps, vtgt); - __m256i total_match = _mm256_or_si256(_mm256_or_si256(bit0, bit1), - _mm256_or_si256(bit2, bit3)); + __m256i bit0 = _mm256_and_si256(cmp0, vbit0); + __m256i bit1 = _mm256_and_si256(cmp1, vbit1); + __m256i bit2 = _mm256_and_si256(cmp2, vbit2); + __m256i bit3 = _mm256_and_si256(cmp3, vbit3); - __m256i res = _mm256_maddubs_epi16(total_match, vmult); - __m128i res_lo = _mm256_castsi256_si128(res); - __m128i res_hi = _mm256_extracti128_si256(res, 1); - __m128i packed = _mm_packus_epi16(res_lo, res_hi); + __m256i total_match = + _mm256_or_si256(_mm256_or_si256(bit0, bit1), _mm256_or_si256(bit2, bit3)); - _mm_storeu_si128((__m128i*)&out[2 * k], packed); + __m256i res = _mm256_maddubs_epi16(total_match, vmult); + __m128i res_lo = _mm256_castsi256_si128(res); + __m128i res_hi = _mm256_extracti128_si256(res, 1); + __m128i packed = _mm_packus_epi16(res_lo, res_hi); - target_x -= block_delta; - } + _mm_storeu_si128((__m128i*)out, packed); #else int cur = 0; - for (size_t i = 0; i < 512; ++i) { + for (size_t i = 0; i < 128; ++i) { const uint64_t w = s[i >> 6]; const int bit = int((w >> (i & 63)) & 1ull); cur += bit ? +1 : -1; @@ -772,6 +905,152 @@ static inline void excess_positions_512(const uint64_t* s, } } #endif + return block_delta; +} + +/** + * @brief Prefix excess in a 128-bit bitstring. + * + * Excess(i) = 2*popcount(bits[0..i-1]) - i for i in [0, 128]. + * + * @param s 2 little-endian uint64_t words (bit 0 of s[0] is the first bit). + * @param end_offset Exclusive prefix boundary, clamped to [0, 128]. + * @return Prefix excess on [0, end_offset). + */ +static inline int prefix_excess_128(const uint64_t* s, + size_t end_offset) noexcept { + end_offset = end_offset > 128 ? 128 : end_offset; + if (end_offset == 0) { + return 0; + } + if (end_offset <= 64) { + const int ones = static_cast(std::popcount( + s[0] & first_bits_mask(static_cast(end_offset)))); + return 2 * ones - static_cast(end_offset); + } + const int ones = static_cast( + std::popcount(s[0]) + + std::popcount(s[1] & + first_bits_mask(static_cast(end_offset - 64)))); + return 2 * ones - static_cast(end_offset); +} + +/** + * @brief Find the first prefix reaching target_x in a 128-bit bitstring. + * + * Searches the prefix excess values represented by excess_positions_128 and + * ignores matches before start_offset. The returned offset is the bit position + * whose inclusive prefix reaches target_x. + * + * @param s 2 little-endian uint64_t words (bit 0 of s[0] is the first bit). + * @param target_x Target excess value relative to the beginning of this + * 128-bit bitstring. + * @param start_offset First bit position eligible for a match, in [0, 128]. + * @param block_excess Optional output for the total excess change across the + * 128-bit bitstring. + * @return Matching bit offset in [0, 127], or 128 if no match exists. + */ +static inline size_t forward_search_128(const uint64_t* s, + int target_x, + size_t start_offset, + int* block_excess = nullptr) noexcept { + uint64_t out[2]; + const int delta = excess_positions_128(s, target_x, out); + if (block_excess != nullptr) { + *block_excess = delta; + } + if (start_offset >= 128) { + return 128; + } + + const size_t first_word = start_offset >> 6; + const size_t first_bit = start_offset & 63; + for (size_t word = first_word; word < 2; ++word) { + uint64_t mask = out[word]; + if (word == first_word && first_bit != 0) { + mask &= ~first_bits_mask(first_bit); + } + if (mask != 0) { + return word * 64 + std::countr_zero(mask); + } + } + return 128; +} + +/** + * @brief Find the last prefix before end_offset reaching target_x in a 128-bit + * bitstring. + * + * Searches prefix boundary positions strictly before end_offset, matching the + * RmM backward-search convention. A return value of 0 is a valid match for the + * chunk-start boundary when target_x is zero. + * + * @param s 2 little-endian uint64_t words (bit 0 of s[0] is the first bit). + * @param target_x Target excess value relative to the beginning of this + * 128-bit bitstring. + * @param end_offset Exclusive right boundary for the search, in [0, 128]. + * @param block_excess Optional output for the total excess change across the + * 128-bit bitstring. + * @return Matching prefix boundary offset in [0, 127], or 128 if no match + * exists. + */ +static inline size_t backward_search_128(const uint64_t* s, + int target_x, + size_t end_offset, + int* block_excess = nullptr) noexcept { + uint64_t out[2]; + const int delta = excess_positions_128(s, target_x, out); + if (block_excess != nullptr) { + *block_excess = delta; + } + if (end_offset == 0) { + return 128; + } + + const size_t max_prefix_length = end_offset - 1; + if (max_prefix_length > 0) { + const size_t last_bit_index = max_prefix_length - 1; + size_t word = last_bit_index >> 6; + const size_t bit_in_word = last_bit_index & 63; + uint64_t mask = out[word] & first_bits_mask(bit_in_word + 1); + while (true) { + if (mask != 0) { + return word * 64 + (63 - std::countl_zero(mask)) + 1; + } + if (word == 0) { + break; + } + --word; + mask = out[word]; + } + } + return target_x == 0 ? 0 : 128; +} + +/** + * @brief Find every prefix whose excess equals target_x in a 512-bit bitstring. + * + * Excess(i) = 2*popcount(bits[0..i-1]) - i for i in [0..512]. + * Bit (w*64 + b) of out[w] is set iff excess(w*64 + b + 1) == target_x. + * I.e. out bit index b corresponds to prefix length (b+1). + * + * @param s 8 little-endian uint64_t words (bit 0 of s[0] is the first bit). + * @param target_x Target excess value in [-512, 512]; outside this range out is + * zeroed. + * @param out 8 uint64_t words receiving the result bitmask. + */ +static inline void excess_positions_512(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + if (target_x < -512 || target_x > 512) { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + return; + } + + for (int k = 0; k < 4; ++k) { + target_x -= excess_positions_128(s + 2 * k, target_x, out + 2 * k); + } } /** diff --git a/include/pixie/bitvector.h b/include/pixie/bitvector.h index b486736..8a27553 100644 --- a/include/pixie/bitvector.h +++ b/include/pixie/bitvector.h @@ -1,785 +1,846 @@ -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include - -#ifdef PIXIE_DIAGNOSTICS -#include -#endif - -namespace pixie { - -/** - * @brief Non-interleaved, non-owning bit vector with rank and select. - * - * - * @details - * This is a two-level rank/select index for a bit vector stored - * externally as - * 64-bit words. The layout follows ideas from: - * - * {1} - * "SPIDER: Improved Succinct Rank and Select Performance" - * Matthew D. Laws, - * Jocelyn Bliven, Kit Conklin, Elyes Laalai, Samuel McCauley, - * Zach S. - * Sturdevant - * https://github.com/williams-cs/spider - * - * {2} "Engineering - * compact data structures for rank and select queries on - * bit vectors" - * Kurpicz F. - * https://github.com/pasta-toolbox/bit_vector - * - * Structure - * overview: - * - Super blocks of 2^16 bits with 64-bit ranks (~0.98% - * overhead). - * - Basic blocks of 512 bits with 16-bit ranks (~3.125% - * overhead). - * - Select samples every 16384 bits (~0.39% overhead). - * - * - * Rank: 2 table lookups plus SIMD popcount in the 512-bit block. - * - * Select: - - * * - Start from a sampled super block. - * - SIMD linear scan to find the super - * block. - * - SIMD linear scan to find the basic block. - * - * This variant does - * not interleave data and index, favoring simpler scans. - */ -class BitVector { - private: - constexpr static size_t kWordSize = 64; - constexpr static size_t kSuperBlockRankIntSize = 64; - constexpr static size_t kBasicBlockRankIntSize = 16; - constexpr static size_t kBasicBlockSize = 512; - constexpr static size_t kWordsPerBlock = 8; - constexpr static size_t kSuperBlockSize = 65536; - constexpr static size_t kBlocksPerSuperBlock = 128; - constexpr static size_t kSelectSampleFrequency = 16384; - - alignas(64) uint64_t delta_super[8]; - alignas(64) uint16_t delta_basic[32]; - - AlignedStorage super_block_rank_; // 64-bit global prefix sums - AlignedStorage basic_block_rank_; // 16-bit local prefix sums - AlignedStorage select1_samples_; // 64-bit global positions - AlignedStorage select0_samples_; // 64-bit global positions - const size_t num_bits_; - const size_t padded_size_; - size_t max_rank_; - - std::span bits_; - - /** - * @brief Precompute rank for fast queries. - */ - void build_rank() { - size_t num_superblocks = 8 + (padded_size_ - 1) / kSuperBlockSize; - // Add more blocks to ease SIMD processing - // num_basicblocks to fully cover superblock, i.e. 128 - // This reduces branching in select - num_superblocks = ((num_superblocks + 7) / 8) * 8; - size_t num_basicblocks = num_superblocks * kBlocksPerSuperBlock; - super_block_rank_.resize(num_superblocks * 64); - basic_block_rank_.resize(num_basicblocks * 16); - - auto super_block_rank = super_block_rank_.As64BitInts(); - auto basic_block_rank = basic_block_rank_.As16BitInts(); - - uint64_t super_block_sum = 0; - uint16_t basic_block_sum = 0; - - for (size_t i = 0; i / kBasicBlockSize < basic_block_rank.size(); - i += kWordSize) { - if (i % kSuperBlockSize == 0) { - super_block_sum += basic_block_sum; - super_block_rank[i / kSuperBlockSize] = super_block_sum; - basic_block_sum = 0; - } - if (i % kBasicBlockSize == 0) { - basic_block_rank[i / kBasicBlockSize] = basic_block_sum; - } - if (i / kWordSize < bits_.size()) { - basic_block_sum += std::popcount(bits_[i / kWordSize]); - } - } - max_rank_ = super_block_sum + basic_block_sum; - } - - /** - * @brief Calculate select samples. - */ - void build_select() { - uint64_t milestone = kSelectSampleFrequency; - uint64_t milestone0 = kSelectSampleFrequency; - uint64_t rank = 0; - uint64_t rank0 = 0; - - size_t num_one_samples = - 1 + (max_rank_ + kSelectSampleFrequency - 1) / kSelectSampleFrequency; - size_t num_zero_samples = - 1 + (num_bits_ - max_rank_ + kSelectSampleFrequency - 1) / - kSelectSampleFrequency; - - select1_samples_.resize(num_one_samples * 64); - select0_samples_.resize(num_zero_samples * 64); - auto select1_samples = select1_samples_.As64BitInts(); - auto select0_samples = select0_samples_.As64BitInts(); - - select1_samples[0] = 0; - select0_samples[0] = 0; - - size_t num_zeros = 1, num_ones = 1; - - for (size_t i = 0; i < bits_.size(); ++i) { - auto ones = std::popcount(bits_[i]); - auto zeros = 64 - ones; - if (rank + ones >= milestone) { - auto pos = select_64(bits_[i], milestone - rank - 1); - // TODO: try including global rank into select samples to save - // a cache miss on global rank scan - select1_samples[num_ones++] = (64 * i + pos) / kSuperBlockSize; - milestone += kSelectSampleFrequency; - } - if (rank0 + zeros >= milestone0) { - auto pos = select_64(~bits_[i], milestone0 - rank0 - 1); - select0_samples[num_zeros++] = (64 * i + pos) / kSuperBlockSize; - milestone0 += kSelectSampleFrequency; - } - rank += ones; - rank0 += zeros; - } - - for (size_t i = 0; i < 8; ++i) { - delta_super[i] = i * kSuperBlockSize; - } - for (size_t i = 0; i < 32; ++i) { - delta_basic[i] = i * kBasicBlockSize; - } - } - - /** - * @brief First step of the select operation. - * @param rank 1-based - * rank of the 1-bit to locate. - */ - uint64_t find_superblock(uint64_t rank) const { - auto select1_samples = select1_samples_.AsConst64BitInts(); - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - - uint64_t left = select1_samples[rank / kSelectSampleFrequency]; - - while (left + 7 < super_block_rank.size()) { - auto len = lower_bound_8x64(&super_block_rank[left], rank); - if (len < 8) { - return left + len - 1; - } - left += 8; - } - if (left + 3 < super_block_rank.size()) { - auto len = lower_bound_4x64(&super_block_rank[left], rank); - if (len < 4) { - return left + len - 1; - } - left += 4; - } - while (left < super_block_rank.size() && super_block_rank[left] < rank) { - left++; - } - return left - 1; - } - - /** - * @brief First step of the select0 operation. - * @param rank0 1-based - * rank of the 0-bit to locate. - */ - uint64_t find_superblock_zeros(uint64_t rank0) const { - auto select0_samples = select0_samples_.AsConst64BitInts(); - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - - uint64_t left = select0_samples[rank0 / kSelectSampleFrequency]; - - while (left + 7 < super_block_rank.size()) { - auto len = lower_bound_delta_8x64(&super_block_rank[left], rank0, - delta_super, kSuperBlockSize * left); - if (len < 8) { - return left + len - 1; - } - left += 8; - } - if (left + 3 < super_block_rank.size()) { - auto len = lower_bound_delta_4x64(&super_block_rank[left], rank0, - delta_super, kSuperBlockSize * left); - if (len < 4) { - return left + len - 1; - } - left += 4; - } - while (left < super_block_rank.size() && - kSuperBlockSize * left - super_block_rank[left] < rank0) { - left++; - } - return left - 1; - } - - /** - * @brief SIMD-optimized linear scan. - * @param local_rank Rank within - * the super block. - * @param s_block Super block index. - * @details - * - * Processes 32 16-bit entries at once (full cache line), so there is at most - - * * 4 iterations. - */ - uint64_t find_basicblock(uint16_t local_rank, uint64_t s_block) const { - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) { - auto count = lower_bound_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); - if (count < 32) { - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - } - return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1; - } - - /** - * @brief SIMD-optimized linear scan. - * @param local_rank0 Rank of - * zeros within the super block. - * @param s_block Super block index. - * - * @details - * Processes 32 16-bit entries at once (full cache line), so - * there is at most - * 4 iterations. - */ - uint64_t find_basicblock_zeros(uint16_t local_rank0, uint64_t s_block) const { - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) { - auto count = lower_bound_delta_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, - delta_basic, kBasicBlockSize * pos); - if (count < 32) { - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - } - return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1; - } - - /** - * @brief Interpolation search with SIMD optimization. - * @param - * local_rank Rank within the super block. - * @param s_block Super block - * index. - * @details - * Similar to find_basicblock but initial guess is - * based on linear - * interpolation, for random data it should make initial - * guess correct - * most of the times, we start from the 32 wide block with - * interpolation - * guess at the center, if we see that select result lie in - * lower blocks - * we backoff to find_basicblock - */ - uint64_t find_basicblock_is(uint16_t local_rank, uint64_t s_block) const { - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - auto lower = super_block_rank[s_block]; - auto upper = super_block_rank[s_block + 1]; - - uint64_t pos = kBlocksPerSuperBlock * local_rank / (upper - lower); - pos = pos + 16 < 32 ? 0 : (pos - 16); - pos = pos > 96 ? 96 : pos; - while (pos < 96) { - auto count = lower_bound_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); - if (count == 0) { - return find_basicblock(local_rank, s_block); - } - if (count < 32) { - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - pos += 32; - } - pos = 96; - auto count = lower_bound_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); - if (count == 0) { - return find_basicblock(local_rank, s_block); - } - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - - /** - * @brief Interpolation search with SIMD optimization. - * @param - * local_rank0 Rank of zeros within the super block. - * @param s_block Super - * block index. - * @details - * Similar to find_basicblock_zeros but - * initial guess is based on linear - * interpolation, for random data it - * should make initial guess correct - * most of the times, we start from the - * 32 wide block with interpolation - * guess at the center, if we see that - * select result lie in lower blocks - * we backoff to find_basicblock_zeros - - */ - uint64_t find_basicblock_is_zeros(uint16_t local_rank0, - uint64_t s_block) const { - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - auto lower = kSuperBlockSize * s_block - super_block_rank[s_block]; - auto upper = - kSuperBlockSize * (s_block + 1) - super_block_rank[s_block + 1]; - - uint64_t pos = kBlocksPerSuperBlock * local_rank0 / (upper - lower); - pos = pos + 16 < 32 ? 0 : (pos - 16); - pos = pos > 96 ? 96 : pos; - while (pos < 96) { - auto count = lower_bound_delta_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, - delta_basic, kBasicBlockSize * pos); - if (count == 0) { - return find_basicblock_zeros(local_rank0, s_block); - } - if (count < 32) { - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - pos += 32; - } - pos = 96; - auto count = lower_bound_delta_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, - delta_basic, kBasicBlockSize * pos); - if (count == 0) { - return find_basicblock_zeros(local_rank0, s_block); - } - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - - public: -#ifdef PIXIE_DIAGNOSTICS - struct DiagnosticsBytes { - size_t source_bitvector_bytes = 0; - size_t super_block_rank_bytes = 0; - size_t basic_block_rank_bytes = 0; - size_t select1_samples_bytes = 0; - size_t select0_samples_bytes = 0; - size_t total_bytes = 0; - }; - - /** - * @brief Returns the number of bytes used by each internal component. - */ - DiagnosticsBytes diagnostics_bytes() const { - DiagnosticsBytes result; - result.source_bitvector_bytes = (num_bits_ + 7) / 8; - result.super_block_rank_bytes = super_block_rank_.AsConstBytes().size(); - result.basic_block_rank_bytes = basic_block_rank_.AsConstBytes().size(); - result.select1_samples_bytes = select1_samples_.AsConstBytes().size(); - result.select0_samples_bytes = select0_samples_.AsConstBytes().size(); - result.total_bytes = - result.super_block_rank_bytes + result.basic_block_rank_bytes + - result.select1_samples_bytes + result.select0_samples_bytes; - return result; - } - - /** - * @brief Log memory usage of internal components. - */ - void memory_report() const { - const auto diagnostics = diagnostics_bytes(); - const double source_bytes = - static_cast(diagnostics.source_bitvector_bytes); - const auto log_bytes = [&](std::string_view label, size_t bytes) { - const double percentage = - source_bytes > 0.0 ? 100.0 * static_cast(bytes) / source_bytes - : 0.0; - spdlog::info("BitVector {}: {} bytes ({:.2f}% of source)", label, bytes, - percentage); - }; - log_bytes("source_bitvector", diagnostics.source_bitvector_bytes); - log_bytes("super_block_rank", diagnostics.super_block_rank_bytes); - log_bytes("basic_block_rank", diagnostics.basic_block_rank_bytes); - log_bytes("select1_samples", diagnostics.select1_samples_bytes); - log_bytes("select0_samples", diagnostics.select0_samples_bytes); - log_bytes("total", diagnostics.total_bytes); - } -#endif - /** - * @brief Construct from an external array of 64-bit words. - * @param - * bit_vector Backing data, not owned. - * @param num_bits Number of valid - * bits in the vector. - */ - explicit BitVector(std::span bit_vector, size_t num_bits) - : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)), - padded_size_(((num_bits_ + kWordSize - 1) / kWordSize) * kWordSize), - bits_(bit_vector) { - build_rank(); - build_select(); - } - - /** - * @brief Returns the number of valid bits. - */ - size_t size() const { return num_bits_; } - - /** - * @brief Returns the bit at the given position. - * @param pos Bit - * index in [0, size()). - */ - int operator[](size_t pos) const { - size_t word_idx = pos / kWordSize; - size_t bit_off = pos % kWordSize; - - return (bits_[word_idx] >> bit_off) & 1; - } - - /** - * @brief Rank of 1s up to position pos (exclusive). - * @param pos Bit - * index in [0, size()]. - * @return Number of 1s in [0, pos). - */ - uint64_t rank(size_t pos) const { - if (pos >= bits_.size() * kWordSize) [[unlikely]] { - return max_rank_; - } - - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - uint64_t b_block = pos / kBasicBlockSize; - uint64_t s_block = pos / kSuperBlockSize; - // Precomputed rank - uint64_t result = super_block_rank[s_block] + basic_block_rank[b_block]; - // Basic block tail - result += rank_512(&bits_[b_block * kWordsPerBlock], - pos - (b_block * kBasicBlockSize)); - return result; - } - - /** - * @brief Rank of 0s up to position pos (exclusive). - * @param pos Bit - * index in [0, size()]. - * @return Number of 0s in [0, pos). - */ - uint64_t rank0(size_t pos) const { - if (pos >= bits_.size() * kWordSize) [[unlikely]] { - return bits_.size() * kWordSize - max_rank_; - } - return pos - rank(pos); - } - - /** - * @brief Select the position of the rank-th 1-bit (1-indexed). - * - * @param rank 1-based rank of the 1-bit to select. - * @return Bit index, or - * size() if rank is out of range. - */ - uint64_t select(size_t rank) const { - if (rank > max_rank_) [[unlikely]] { - return num_bits_; - } - if (rank == 0) [[unlikely]] { - return 0; - } - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - uint64_t s_block = find_superblock(rank); - rank -= super_block_rank[s_block]; - auto pos = find_basicblock_is(rank, s_block); - rank -= basic_block_rank[pos]; - pos *= kWordsPerBlock; - - // Final search - if (pos + kWordsPerBlock - 1 < kWordsPerBlock) [[unlikely]] { - size_t ones = std::popcount(bits_[pos]); - while (pos < bits_.size() && ones < rank) { - rank -= ones; - ones = std::popcount(bits_[++pos]); - } - return kWordSize * pos + select_64(bits_[pos], rank - 1); - } - return kWordSize * pos + select_512(&bits_[pos], rank - 1); - } - - /** - * @brief Select the position of the rank0-th 0-bit (1-indexed). - * - * @param rank0 1-based rank of the 0-bit to select. - * @return Bit index, - * or size() if rank0 is out of range. - */ - uint64_t select0(size_t rank0) const { - if (rank0 > num_bits_ - max_rank_) [[unlikely]] { - return num_bits_; - } - if (rank0 == 0) [[unlikely]] { - return 0; - } - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - uint64_t s_block = find_superblock_zeros(rank0); - rank0 -= kSuperBlockSize * s_block - super_block_rank[s_block]; - auto pos = find_basicblock_is_zeros(rank0, s_block); - auto pos_in_super_block = pos & (kBlocksPerSuperBlock - 1); - rank0 -= kBasicBlockSize * pos_in_super_block - basic_block_rank[pos]; - pos *= kWordsPerBlock; - - // Final search - if (pos + kWordsPerBlock - 1 < kWordsPerBlock) [[unlikely]] { - size_t zeros = std::popcount(~bits_[pos]); - while (pos < bits_.size() && zeros < rank0) { - rank0 -= zeros; - zeros = std::popcount(~bits_[++pos]); - } - return kWordSize * pos + select_64(~bits_[pos], rank0 - 1); - } - return kWordSize * pos + select0_512(&bits_[pos], rank0 - 1); - } - - /** - * @brief Convert to a binary string (debug helper). - */ - std::string to_string() const { - std::string result; - result.reserve(num_bits_); - - for (size_t i = 0; i < num_bits_; i++) { - result.push_back(operator[](i) ? '1' : '0'); - } - - return result; - } -}; - -/** - * @brief Interleaved, owning bit vector with rank and select. - * - * - * @details - * This variant interleaves data with local rank metadata to reduce - * cache - * misses for rank queries. It copies input bits into an interleaved - * layout. - * - * Based on: - * "SPIDER: Improved Succinct Rank and Select - * Performance" - * Matthew D. Laws, Jocelyn Bliven, Kit Conklin, Elyes Laalai, - * Samuel McCauley, - * Zach S. Sturdevant - */ -class BitVectorInterleaved { - private: - constexpr static size_t kWordSize = 64; - constexpr static size_t kSuperBlockRankIntSize = 64; - constexpr static size_t kBasicBlockRankIntSize = 16; - /** - * 496 bits data + 16 bit local rank - */ - constexpr static size_t kBasicBlockSize = 496; - /** - * 63488 = 496 * 128, so position of superblock can be obtained - * from the position of basic block by dividing on 128 or - * right shift on 7 bits which is cheaper then performing another - * division. - */ - constexpr static size_t kSuperBlockSize = 63488; - constexpr static size_t kBlocksPerSuperBlock = 128; - constexpr static size_t kWordsPerBlock = 8; - - const size_t num_bits_; - std::vector bits_interleaved; - std::vector super_block_rank_; - - class BitReader { - size_t iterator_64_ = 0; - size_t offset_size_ = 0; - size_t offset_bits_ = 0; - std::span bits_; - - public: - BitReader(std::span bits) : bits_(bits) {} - uint64_t ReadBits64(size_t num_bits) { - if (num_bits > 64) { - num_bits = 64; - } - uint64_t result = offset_bits_ & first_bits_mask(num_bits); - if (offset_size_ >= num_bits) { - offset_bits_ >>= num_bits; - offset_size_ -= num_bits; - return result; - } - uint64_t next = iterator_64_ < bits_.size() ? bits_[iterator_64_++] : 0; - result ^= (next & first_bits_mask(num_bits - offset_size_)) - << offset_size_; - offset_bits_ = (num_bits - offset_size_ == 64) - ? 0 - : next >> (num_bits - offset_size_); - offset_size_ = 64 - (num_bits - offset_size_); - return result; - } - }; - - public: - /** - * @brief Construct from an external array of 64-bit words. - * @param - * bit_vector Backing data to copy and interleave. - * @param num_bits Number - * of valid bits in the vector. - */ - explicit BitVectorInterleaved(std::span bit_vector, - size_t num_bits) - : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)) { - build_rank_interleaved(bit_vector, num_bits); - } - - /** - * @brief Mask with the lowest num bits set. - */ - static inline uint64_t first_bits_mask(size_t num) { - return num >= 64 ? UINT64_MAX : ((1llu << num) - 1); - } - - /** - * @brief Returns the number of valid bits. - */ - size_t size() const { return num_bits_; } - - /** - * @brief Returns the bit at the given position. - * @param pos Bit - * index in [0, size()). - */ - int operator[](size_t pos) const { - size_t block_id = pos / kBasicBlockSize; - size_t block_bit = pos - block_id * kBasicBlockSize; - size_t word_id = block_id * kWordsPerBlock + block_bit / kWordSize; - size_t word_bit = block_bit % kWordSize; - kWordSize; - - return (bits_interleaved[word_id] >> word_bit) & 1; - } - - /** - * @brief Build the interleaved layout and rank index. - * @param bits - * Source bit vector as 64-bit words. - * @param num_bits Number of valid - * bits in the source. - */ - void build_rank_interleaved(std::span bits, size_t num_bits) { - size_t num_superblocks = 1 + (num_bits_ - 1) / kSuperBlockSize; - super_block_rank_.resize(num_superblocks); - size_t num_basicblocks = 1 + (num_bits_ - 1) / kBasicBlockSize; - bits_interleaved.resize(num_basicblocks * (512 / kWordSize)); - - uint64_t super_block_sum = 0; - uint16_t basic_block_sum = 0; - auto bit_reader = BitReader(bits); - - for (size_t i = 0; i * kBasicBlockSize < num_bits; ++i) { - if (i % (kSuperBlockSize / kBasicBlockSize) == 0) { - super_block_sum += basic_block_sum; - super_block_rank_[i / (kSuperBlockSize / kBasicBlockSize)] = - super_block_sum; - basic_block_sum = 0; - } - bits_interleaved[i * (kWordsPerBlock) + 7] = - static_cast(basic_block_sum) << 48; - - for (size_t j = 0; j < 7 && kWordSize * (i + j) < num_bits; ++j) { - bits_interleaved[i * (kWordsPerBlock) + j] = - bit_reader.ReadBits64(std::min( - 64ull, num_bits - i * kBasicBlockSize + j * kWordSize)); - basic_block_sum += - std::popcount(bits_interleaved[i * (kWordsPerBlock) + j]); - } - if ((i + 7) * kWordSize < num_bits) { - auto v = bit_reader.ReadBits64(std::min( - 48ull, num_bits - (i * kBasicBlockSize + 7 * kWordSize))); - bits_interleaved[i * (kWordsPerBlock) + 7] ^= v; - basic_block_sum += std::popcount(v); - } - } - } - - /** - * @brief Rank of 1s up to position pos (exclusive). - * @param pos Bit - * index in [0, size()]. - * @return Number of 1s in [0, pos). - */ - uint64_t rank(size_t pos) const { - // Multiplication/devisions - uint64_t b_block = pos / kBasicBlockSize; - uint64_t s_block = b_block / kBlocksPerSuperBlock; - uint64_t b_block_pos = b_block * kWordsPerBlock; - // Super block rank - uint64_t result = super_block_rank_[s_block]; - /** - * Ok, so here's quite the important factor to load 512-bit region - * at &bits_interleaved[b_block_pos], we store local rank as 16 last - * bits of it. Prefetch should guarantee but seems like there is no - * need for it. - */ - // __builtin_prefetch(&bits_interleaved[b_block_pos]); - result += rank_512(&bits_interleaved[b_block_pos], - pos - (b_block * kBasicBlockSize)); - result += bits_interleaved[b_block_pos + 7] >> 48; - return result; - } - - /** - * @brief Convert to a binary string (debug helper). - */ - std::string to_string() const { - std::string result; - result.reserve(num_bits_); - - for (size_t i = 0; i < num_bits_; i++) { - result.push_back(operator[](i) ? '1' : '0'); - } - - return result; - } -}; - -} // namespace pixie +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#ifdef PIXIE_DIAGNOSTICS +#include +#endif + +namespace pixie { + +/** + * @brief Non-interleaved, non-owning bit vector with rank and select. + * + * + * @details + * This is a two-level rank/select index for a bit vector stored + * externally as + * 64-bit words. The layout follows ideas from: + * + * {1} + * "SPIDER: Improved Succinct Rank and Select Performance" + * Matthew D. Laws, + * Jocelyn Bliven, Kit Conklin, Elyes Laalai, Samuel McCauley, + * Zach S. + * Sturdevant + * https://github.com/williams-cs/spider + * + * {2} "Engineering + * compact data structures for rank and select queries on + * bit vectors" + * Kurpicz F. + * https://github.com/pasta-toolbox/bit_vector + * + * Structure + * overview: + * - Super blocks of 2^16 bits with 64-bit ranks (~0.98% + * overhead). + * - Basic blocks of 512 bits with 16-bit ranks (~3.125% + * overhead). + * - Select samples every 16384 bits (~0.39% overhead). + * + * + * Rank: 2 table lookups plus SIMD popcount in the 512-bit block. + * + * Select: + + * * - Start from a sampled super block. + * - SIMD linear scan to find the super + * block. + * - SIMD linear scan to find the basic block. + * + * This variant does + * not interleave data and index, favoring simpler scans. + */ +class BitVector { + private: + constexpr static size_t kWordSize = 64; + constexpr static size_t kSuperBlockRankIntSize = 64; + constexpr static size_t kBasicBlockRankIntSize = 16; + constexpr static size_t kBasicBlockSize = 512; + constexpr static size_t kWordsPerBlock = 8; + constexpr static size_t kSuperBlockSize = 65536; + constexpr static size_t kBlocksPerSuperBlock = 128; + constexpr static size_t kSelectSampleFrequency = 16384; + + alignas(64) uint64_t delta_super[8]; + alignas(64) uint16_t delta_basic[32]; + + AlignedStorage super_block_rank_; // 64-bit global prefix sums + AlignedStorage basic_block_rank_; // 16-bit local prefix sums + AlignedStorage select1_samples_; // 64-bit global positions + AlignedStorage select0_samples_; // 64-bit global positions + size_t num_bits_{}; + size_t padded_size_{}; + size_t max_rank_{}; + + std::span bits_; + + size_t logical_word_count() const { + return (num_bits_ + kWordSize - 1) / kWordSize; + } + + size_t logical_word_bits(size_t word_index) const { + const size_t begin = word_index * kWordSize; + if (begin >= num_bits_) { + return 0; + } + return std::min(kWordSize, num_bits_ - begin); + } + + uint64_t logical_word(size_t word_index) const { + if (word_index >= bits_.size()) { + return 0; + } + const size_t bits = logical_word_bits(word_index); + if (bits == 0) { + return 0; + } + if (bits == kWordSize) { + return bits_[word_index]; + } + return bits_[word_index] & first_bits_mask(bits); + } + + uint64_t rank_in_basic_block(size_t basic_block, size_t offset) const { + if (offset == 0) { + return 0; + } + const size_t first_word = basic_block * kWordsPerBlock; + if (first_word + kWordsPerBlock <= bits_.size()) { + return rank_512(&bits_[first_word], offset); + } + + uint64_t result = 0; + size_t word_index = first_word; + while (offset >= kWordSize) { + result += std::popcount(logical_word(word_index)); + offset -= kWordSize; + ++word_index; + } + if (offset != 0) { + result += + std::popcount(logical_word(word_index) & first_bits_mask(offset)); + } + return result; + } + + uint64_t select_in_words(size_t first_word, size_t rank, bool value) const { + const size_t first_bit = first_word * kWordSize; + if (first_bit + kBasicBlockSize <= num_bits_ && + first_word + kWordsPerBlock <= bits_.size()) { + return value ? first_bit + select_512(&bits_[first_word], rank - 1) + : first_bit + select0_512(&bits_[first_word], rank - 1); + } + + for (size_t word_index = first_word; word_index < logical_word_count(); + ++word_index) { + const uint64_t word = logical_word(word_index); + const uint64_t candidates = + value ? word + : (~word & first_bits_mask(logical_word_bits(word_index))); + const size_t count = std::popcount(candidates); + if (rank > count) { + rank -= count; + continue; + } + return word_index * kWordSize + select_64(candidates, rank - 1); + } + return num_bits_; + } + + /** + * @brief Precompute rank for fast queries. + */ + void build_rank() { + size_t num_superblocks = + 8 + (padded_size_ == 0 ? 0 : (padded_size_ - 1) / kSuperBlockSize); + // Add more blocks to ease SIMD processing + // num_basicblocks to fully cover superblock, i.e. 128 + // This reduces branching in select + num_superblocks = ((num_superblocks + 7) / 8) * 8; + size_t num_basicblocks = num_superblocks * kBlocksPerSuperBlock; + super_block_rank_.resize(num_superblocks * 64); + basic_block_rank_.resize(num_basicblocks * 16); + + auto super_block_rank = super_block_rank_.As64BitInts(); + auto basic_block_rank = basic_block_rank_.As16BitInts(); + + uint64_t super_block_sum = 0; + uint64_t basic_block_sum = 0; + + for (size_t i = 0; i / kBasicBlockSize < basic_block_rank.size(); + i += kWordSize) { + if (i % kSuperBlockSize == 0) { + super_block_sum += basic_block_sum; + super_block_rank[i / kSuperBlockSize] = super_block_sum; + basic_block_sum = 0; + } + if (i % kBasicBlockSize == 0) { + basic_block_rank[i / kBasicBlockSize] = + static_cast(basic_block_sum); + } + if (i / kWordSize < logical_word_count()) { + basic_block_sum += std::popcount(logical_word(i / kWordSize)); + } + } + max_rank_ = super_block_sum + basic_block_sum; + } + + /** + * @brief Calculate select samples. + */ + void build_select() { + uint64_t milestone = kSelectSampleFrequency; + uint64_t milestone0 = kSelectSampleFrequency; + uint64_t rank = 0; + uint64_t rank0 = 0; + + size_t num_one_samples = + 1 + (max_rank_ + kSelectSampleFrequency - 1) / kSelectSampleFrequency; + size_t num_zero_samples = + 1 + (num_bits_ - max_rank_ + kSelectSampleFrequency - 1) / + kSelectSampleFrequency; + + select1_samples_.resize(num_one_samples * 64); + select0_samples_.resize(num_zero_samples * 64); + auto select1_samples = select1_samples_.As64BitInts(); + auto select0_samples = select0_samples_.As64BitInts(); + + select1_samples[0] = 0; + select0_samples[0] = 0; + + size_t num_zeros = 1, num_ones = 1; + + for (size_t i = 0; i < logical_word_count(); ++i) { + const uint64_t word = logical_word(i); + const auto ones = std::popcount(word); + const auto zeros = logical_word_bits(i) - ones; + if (rank + ones >= milestone) { + auto pos = select_64(word, milestone - rank - 1); + // TODO: try including global rank into select samples to save + // a cache miss on global rank scan + select1_samples[num_ones++] = (64 * i + pos) / kSuperBlockSize; + milestone += kSelectSampleFrequency; + } + if (rank0 + zeros >= milestone0) { + const uint64_t zero_word = + ~word & first_bits_mask(logical_word_bits(i)); + auto pos = select_64(zero_word, milestone0 - rank0 - 1); + select0_samples[num_zeros++] = (64 * i + pos) / kSuperBlockSize; + milestone0 += kSelectSampleFrequency; + } + rank += ones; + rank0 += zeros; + } + + for (size_t i = 0; i < 8; ++i) { + delta_super[i] = i * kSuperBlockSize; + } + for (size_t i = 0; i < 32; ++i) { + delta_basic[i] = i * kBasicBlockSize; + } + } + + /** + * @brief First step of the select operation. + * @param rank 1-based + * rank of the 1-bit to locate. + */ + uint64_t find_superblock(uint64_t rank) const { + auto select1_samples = select1_samples_.AsConst64BitInts(); + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + + uint64_t left = select1_samples[rank / kSelectSampleFrequency]; + + while (left + 7 < super_block_rank.size()) { + auto len = lower_bound_8x64(&super_block_rank[left], rank); + if (len < 8) { + return left + len - 1; + } + left += 8; + } + if (left + 3 < super_block_rank.size()) { + auto len = lower_bound_4x64(&super_block_rank[left], rank); + if (len < 4) { + return left + len - 1; + } + left += 4; + } + while (left < super_block_rank.size() && super_block_rank[left] < rank) { + left++; + } + return left - 1; + } + + /** + * @brief First step of the select0 operation. + * @param rank0 1-based + * rank of the 0-bit to locate. + */ + uint64_t find_superblock_zeros(uint64_t rank0) const { + auto select0_samples = select0_samples_.AsConst64BitInts(); + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + + uint64_t left = select0_samples[rank0 / kSelectSampleFrequency]; + + while (left + 7 < super_block_rank.size()) { + auto len = lower_bound_delta_8x64(&super_block_rank[left], rank0, + delta_super, kSuperBlockSize * left); + if (len < 8) { + return left + len - 1; + } + left += 8; + } + if (left + 3 < super_block_rank.size()) { + auto len = lower_bound_delta_4x64(&super_block_rank[left], rank0, + delta_super, kSuperBlockSize * left); + if (len < 4) { + return left + len - 1; + } + left += 4; + } + while (left < super_block_rank.size() && + kSuperBlockSize * left - super_block_rank[left] < rank0) { + left++; + } + return left - 1; + } + + /** + * @brief SIMD-optimized linear scan. + * @param local_rank Rank within + * the super block. + * @param s_block Super block index. + * @details + * + * Processes 32 16-bit entries at once (full cache line), so there is at most + + * * 4 iterations. + */ + uint64_t find_basicblock(uint16_t local_rank, uint64_t s_block) const { + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) { + auto count = lower_bound_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); + if (count < 32) { + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + } + return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1; + } + + /** + * @brief SIMD-optimized linear scan. + * @param local_rank0 Rank of + * zeros within the super block. + * @param s_block Super block index. + * + * @details + * Processes 32 16-bit entries at once (full cache line), so + * there is at most + * 4 iterations. + */ + uint64_t find_basicblock_zeros(uint16_t local_rank0, uint64_t s_block) const { + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) { + auto count = lower_bound_delta_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, + delta_basic, kBasicBlockSize * pos); + if (count < 32) { + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + } + return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1; + } + + /** + * @brief Interpolation search with SIMD optimization. + * @param + * local_rank Rank within the super block. + * @param s_block Super block + * index. + * @details + * Similar to find_basicblock but initial guess is + * based on linear + * interpolation, for random data it should make initial + * guess correct + * most of the times, we start from the 32 wide block with + * interpolation + * guess at the center, if we see that select result lie in + * lower blocks + * we backoff to find_basicblock + */ + uint64_t find_basicblock_is(uint16_t local_rank, uint64_t s_block) const { + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + auto lower = super_block_rank[s_block]; + auto upper = super_block_rank[s_block + 1]; + + uint64_t pos = kBlocksPerSuperBlock * local_rank / (upper - lower); + pos = pos + 16 < 32 ? 0 : (pos - 16); + pos = pos > 96 ? 96 : pos; + while (pos < 96) { + auto count = lower_bound_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); + if (count == 0) { + return find_basicblock(local_rank, s_block); + } + if (count < 32) { + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + pos += 32; + } + pos = 96; + auto count = lower_bound_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); + if (count == 0) { + return find_basicblock(local_rank, s_block); + } + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + + /** + * @brief Interpolation search with SIMD optimization. + * @param + * local_rank0 Rank of zeros within the super block. + * @param s_block Super + * block index. + * @details + * Similar to find_basicblock_zeros but + * initial guess is based on linear + * interpolation, for random data it + * should make initial guess correct + * most of the times, we start from the + * 32 wide block with interpolation + * guess at the center, if we see that + * select result lie in lower blocks + * we backoff to find_basicblock_zeros + + */ + uint64_t find_basicblock_is_zeros(uint16_t local_rank0, + uint64_t s_block) const { + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + auto lower = kSuperBlockSize * s_block - super_block_rank[s_block]; + auto upper = + kSuperBlockSize * (s_block + 1) - super_block_rank[s_block + 1]; + + uint64_t pos = kBlocksPerSuperBlock * local_rank0 / (upper - lower); + pos = pos + 16 < 32 ? 0 : (pos - 16); + pos = pos > 96 ? 96 : pos; + while (pos < 96) { + auto count = lower_bound_delta_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, + delta_basic, kBasicBlockSize * pos); + if (count == 0) { + return find_basicblock_zeros(local_rank0, s_block); + } + if (count < 32) { + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + pos += 32; + } + pos = 96; + auto count = lower_bound_delta_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, + delta_basic, kBasicBlockSize * pos); + if (count == 0) { + return find_basicblock_zeros(local_rank0, s_block); + } + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + + public: + BitVector() = default; + BitVector(const BitVector&) = default; + BitVector(BitVector&&) noexcept = default; + BitVector& operator=(const BitVector&) = default; + BitVector& operator=(BitVector&&) noexcept = default; + +#ifdef PIXIE_DIAGNOSTICS + struct DiagnosticsBytes { + size_t source_bitvector_bytes = 0; + size_t super_block_rank_bytes = 0; + size_t basic_block_rank_bytes = 0; + size_t select1_samples_bytes = 0; + size_t select0_samples_bytes = 0; + size_t total_bytes = 0; + }; + + /** + * @brief Returns the number of bytes used by each internal component. + */ + DiagnosticsBytes diagnostics_bytes() const { + DiagnosticsBytes result; + result.source_bitvector_bytes = (num_bits_ + 7) / 8; + result.super_block_rank_bytes = super_block_rank_.AsConstBytes().size(); + result.basic_block_rank_bytes = basic_block_rank_.AsConstBytes().size(); + result.select1_samples_bytes = select1_samples_.AsConstBytes().size(); + result.select0_samples_bytes = select0_samples_.AsConstBytes().size(); + result.total_bytes = + result.super_block_rank_bytes + result.basic_block_rank_bytes + + result.select1_samples_bytes + result.select0_samples_bytes; + return result; + } + + /** + * @brief Log memory usage of internal components. + */ + void memory_report() const { + const auto diagnostics = diagnostics_bytes(); + const double source_bytes = + static_cast(diagnostics.source_bitvector_bytes); + const auto log_bytes = [&](std::string_view label, size_t bytes) { + const double percentage = + source_bytes > 0.0 ? 100.0 * static_cast(bytes) / source_bytes + : 0.0; + spdlog::info("BitVector {}: {} bytes ({:.2f}% of source)", label, bytes, + percentage); + }; + log_bytes("source_bitvector", diagnostics.source_bitvector_bytes); + log_bytes("super_block_rank", diagnostics.super_block_rank_bytes); + log_bytes("basic_block_rank", diagnostics.basic_block_rank_bytes); + log_bytes("select1_samples", diagnostics.select1_samples_bytes); + log_bytes("select0_samples", diagnostics.select0_samples_bytes); + log_bytes("total", diagnostics.total_bytes); + } +#endif + /** + * @brief Construct from an external array of 64-bit words. + * @param + * bit_vector Backing data, not owned. + * @param num_bits Number of valid + * bits in the vector. + */ + explicit BitVector(std::span bit_vector, size_t num_bits) + : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)), + padded_size_(((num_bits_ + kWordSize - 1) / kWordSize) * kWordSize), + bits_(bit_vector) { + build_rank(); + build_select(); + } + + /** + * @brief Returns the number of valid bits. + */ + size_t size() const { return num_bits_; } + + /** + * @brief Returns the bit at the given position. + * @param pos Bit + * index in [0, size()). + */ + int operator[](size_t pos) const { + size_t word_idx = pos / kWordSize; + size_t bit_off = pos % kWordSize; + + return (bits_[word_idx] >> bit_off) & 1; + } + + /** + * @brief Rank of 1s up to position pos (exclusive). + * @param pos Bit + * index in [0, size()]. + * @return Number of 1s in [0, pos). + */ + uint64_t rank(size_t pos) const { + if (pos >= num_bits_) [[unlikely]] { + return max_rank_; + } + + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + uint64_t b_block = pos / kBasicBlockSize; + uint64_t s_block = pos / kSuperBlockSize; + // Precomputed rank + uint64_t result = super_block_rank[s_block] + basic_block_rank[b_block]; + // Basic block tail + result += rank_in_basic_block(b_block, pos - (b_block * kBasicBlockSize)); + return result; + } + + /** + * @brief Rank of 0s up to position pos (exclusive). + * @param pos Bit + * index in [0, size()]. + * @return Number of 0s in [0, pos). + */ + uint64_t rank0(size_t pos) const { + if (pos >= num_bits_) [[unlikely]] { + return num_bits_ - max_rank_; + } + return pos - rank(pos); + } + + /** + * @brief Select the position of the rank-th 1-bit (1-indexed). + * + * @param rank 1-based rank of the 1-bit to select. + * @return Bit index, or + * size() if rank is out of range. + */ + uint64_t select(size_t rank) const { + if (rank > max_rank_) [[unlikely]] { + return num_bits_; + } + if (rank == 0) [[unlikely]] { + return 0; + } + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + uint64_t s_block = find_superblock(rank); + rank -= super_block_rank[s_block]; + auto pos = find_basicblock_is(rank, s_block); + rank -= basic_block_rank[pos]; + return select_in_words(pos * kWordsPerBlock, rank, true); + } + + /** + * @brief Select the position of the rank0-th 0-bit (1-indexed). + * + * @param rank0 1-based rank of the 0-bit to select. + * @return Bit index, + * or size() if rank0 is out of range. + */ + uint64_t select0(size_t rank0) const { + if (rank0 > num_bits_ - max_rank_) [[unlikely]] { + return num_bits_; + } + if (rank0 == 0) [[unlikely]] { + return 0; + } + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + uint64_t s_block = find_superblock_zeros(rank0); + rank0 -= kSuperBlockSize * s_block - super_block_rank[s_block]; + auto pos = find_basicblock_is_zeros(rank0, s_block); + auto pos_in_super_block = pos & (kBlocksPerSuperBlock - 1); + rank0 -= kBasicBlockSize * pos_in_super_block - basic_block_rank[pos]; + return select_in_words(pos * kWordsPerBlock, rank0, false); + } + + /** + * @brief Convert to a binary string (debug helper). + */ + std::string to_string() const { + std::string result; + result.reserve(num_bits_); + + for (size_t i = 0; i < num_bits_; i++) { + result.push_back(operator[](i) ? '1' : '0'); + } + + return result; + } +}; + +/** + * @brief Interleaved, owning bit vector with rank and select. + * + * + * @details + * This variant interleaves data with local rank metadata to reduce + * cache + * misses for rank queries. It copies input bits into an interleaved + * layout. + * + * Based on: + * "SPIDER: Improved Succinct Rank and Select + * Performance" + * Matthew D. Laws, Jocelyn Bliven, Kit Conklin, Elyes Laalai, + * Samuel McCauley, + * Zach S. Sturdevant + */ +class BitVectorInterleaved { + private: + constexpr static size_t kWordSize = 64; + constexpr static size_t kSuperBlockRankIntSize = 64; + constexpr static size_t kBasicBlockRankIntSize = 16; + /** + * 496 bits data + 16 bit local rank + */ + constexpr static size_t kBasicBlockSize = 496; + /** + * 63488 = 496 * 128, so position of superblock can be obtained + * from the position of basic block by dividing on 128 or + * right shift on 7 bits which is cheaper then performing another + * division. + */ + constexpr static size_t kSuperBlockSize = 63488; + constexpr static size_t kBlocksPerSuperBlock = 128; + constexpr static size_t kWordsPerBlock = 8; + + const size_t num_bits_; + std::vector bits_interleaved; + std::vector super_block_rank_; + + class BitReader { + size_t iterator_64_ = 0; + size_t offset_size_ = 0; + size_t offset_bits_ = 0; + std::span bits_; + + public: + BitReader(std::span bits) : bits_(bits) {} + uint64_t ReadBits64(size_t num_bits) { + if (num_bits > 64) { + num_bits = 64; + } + uint64_t result = offset_bits_ & first_bits_mask(num_bits); + if (offset_size_ >= num_bits) { + offset_bits_ >>= num_bits; + offset_size_ -= num_bits; + return result; + } + uint64_t next = iterator_64_ < bits_.size() ? bits_[iterator_64_++] : 0; + result ^= (next & first_bits_mask(num_bits - offset_size_)) + << offset_size_; + offset_bits_ = (num_bits - offset_size_ == 64) + ? 0 + : next >> (num_bits - offset_size_); + offset_size_ = 64 - (num_bits - offset_size_); + return result; + } + }; + + public: + /** + * @brief Construct from an external array of 64-bit words. + * @param + * bit_vector Backing data to copy and interleave. + * @param num_bits Number + * of valid bits in the vector. + */ + explicit BitVectorInterleaved(std::span bit_vector, + size_t num_bits) + : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)) { + build_rank_interleaved(bit_vector, num_bits); + } + + /** + * @brief Mask with the lowest num bits set. + */ + static inline uint64_t first_bits_mask(size_t num) { + return num >= 64 ? UINT64_MAX : ((1llu << num) - 1); + } + + /** + * @brief Returns the number of valid bits. + */ + size_t size() const { return num_bits_; } + + /** + * @brief Returns the bit at the given position. + * @param pos Bit + * index in [0, size()). + */ + int operator[](size_t pos) const { + size_t block_id = pos / kBasicBlockSize; + size_t block_bit = pos - block_id * kBasicBlockSize; + size_t word_id = block_id * kWordsPerBlock + block_bit / kWordSize; + size_t word_bit = block_bit % kWordSize; + kWordSize; + + return (bits_interleaved[word_id] >> word_bit) & 1; + } + + /** + * @brief Build the interleaved layout and rank index. + * @param bits + * Source bit vector as 64-bit words. + * @param num_bits Number of valid + * bits in the source. + */ + void build_rank_interleaved(std::span bits, size_t num_bits) { + size_t num_superblocks = 1 + (num_bits_ - 1) / kSuperBlockSize; + super_block_rank_.resize(num_superblocks); + size_t num_basicblocks = 1 + (num_bits_ - 1) / kBasicBlockSize; + bits_interleaved.resize(num_basicblocks * (512 / kWordSize)); + + uint64_t super_block_sum = 0; + uint16_t basic_block_sum = 0; + auto bit_reader = BitReader(bits); + + for (size_t i = 0; i * kBasicBlockSize < num_bits; ++i) { + if (i % (kSuperBlockSize / kBasicBlockSize) == 0) { + super_block_sum += basic_block_sum; + super_block_rank_[i / (kSuperBlockSize / kBasicBlockSize)] = + super_block_sum; + basic_block_sum = 0; + } + bits_interleaved[i * (kWordsPerBlock) + 7] = + static_cast(basic_block_sum) << 48; + + for (size_t j = 0; j < 7 && kWordSize * (i + j) < num_bits; ++j) { + bits_interleaved[i * (kWordsPerBlock) + j] = + bit_reader.ReadBits64(std::min( + 64ull, num_bits - i * kBasicBlockSize + j * kWordSize)); + basic_block_sum += + std::popcount(bits_interleaved[i * (kWordsPerBlock) + j]); + } + if ((i + 7) * kWordSize < num_bits) { + auto v = bit_reader.ReadBits64(std::min( + 48ull, num_bits - (i * kBasicBlockSize + 7 * kWordSize))); + bits_interleaved[i * (kWordsPerBlock) + 7] ^= v; + basic_block_sum += std::popcount(v); + } + } + } + + /** + * @brief Rank of 1s up to position pos (exclusive). + * @param pos Bit + * index in [0, size()]. + * @return Number of 1s in [0, pos). + */ + uint64_t rank(size_t pos) const { + // Multiplication/devisions + uint64_t b_block = pos / kBasicBlockSize; + uint64_t s_block = b_block / kBlocksPerSuperBlock; + uint64_t b_block_pos = b_block * kWordsPerBlock; + // Super block rank + uint64_t result = super_block_rank_[s_block]; + /** + * Ok, so here's quite the important factor to load 512-bit region + * at &bits_interleaved[b_block_pos], we store local rank as 16 last + * bits of it. Prefetch should guarantee but seems like there is no + * need for it. + */ + // __builtin_prefetch(&bits_interleaved[b_block_pos]); + result += rank_512(&bits_interleaved[b_block_pos], + pos - (b_block * kBasicBlockSize)); + result += bits_interleaved[b_block_pos + 7] >> 48; + return result; + } + + /** + * @brief Convert to a binary string (debug helper). + */ + std::string to_string() const { + std::string result; + result.reserve(num_bits_); + + for (size_t i = 0; i < num_bits_; i++) { + result.push_back(operator[](i) ? '1' : '0'); + } + + return result; + } +}; + +} // namespace pixie diff --git a/include/pixie/experimental/excess.h b/include/pixie/experimental/excess.h index de8de0f..ef64dde 100644 --- a/include/pixie/experimental/excess.h +++ b/include/pixie/experimental/excess.h @@ -647,4 +647,62 @@ static inline void excess_positions_512_expand_avx512(const uint64_t* s, #endif } +struct ExcessByteLut { + uint8_t masks[256][17]; // target index: T + 8 + int8_t deltas[256]; + + constexpr ExcessByteLut() : masks{}, deltas{} { + for (int b = 0; b < 256; ++b) { + int pop = 0; + for (int i = 0; i < 8; ++i) { + if ((b >> i) & 1) { + pop++; + } + } + deltas[b] = static_cast(2 * pop - 8); + + for (int t = -8; t <= 8; ++t) { + uint8_t mask = 0; + int cur_pop = 0; + for (int i = 0; i < 8; ++i) { + if ((b >> i) & 1) { + cur_pop++; + } + int excess = 2 * cur_pop - (i + 1); + if (excess == t) { + mask |= (1 << i); + } + } + masks[b][t + 8] = mask; + } + } + } +}; + +inline constexpr ExcessByteLut kExcessByteLut; + +static inline void excess_positions_512_byte_lut(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + + if (target_x < -512 || target_x > 512) { + return; + } + + const uint8_t* bytes = reinterpret_cast(s); + uint8_t* out_bytes = reinterpret_cast(out); + + int cur = 0; + for (int i = 0; i < 64; ++i) { + const uint8_t b = bytes[i]; + const int target_rel = target_x - cur; + if (target_rel >= -8 && target_rel <= 8) { + out_bytes[i] = kExcessByteLut.masks[b][target_rel + 8]; + } + cur += kExcessByteLut.deltas[b]; + } +} + } // namespace pixie::experimental diff --git a/include/pixie/experimental/rmm_btree.h b/include/pixie/experimental/rmm_btree.h new file mode 100644 index 0000000..1d83a70 --- /dev/null +++ b/include/pixie/experimental/rmm_btree.h @@ -0,0 +1,1986 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace pixie::experimental { + +/** + * @brief Cache-aligned btree implementation of the range min-max index. + * @details `RmMBTree` is a non-owning succinct index over a little-endian bit + * sequence. It provides the `RmMBase` operations by combining a rank/select + * helper with a hierarchy of min-max summaries over fixed 512-bit blocks. Each + * summary stores subtree size, number of ones, total excess, minimum/maximum + * relative excess, and minimum multiplicity, which lets searches skip whole + * subtrees when the requested excess cannot occur there. + * + * The internal layout has three layers: + * - the original bit words, referenced by `std::span`; + * - level 0 block summaries, reconstructed from parent nodes and scanned with + * 128-bit chunk primitives inside each 512-bit block; + * - higher btree levels stored as cache-line-aligned summary nodes. The first + * level above blocks uses compact `int16_t` excess fields because one low + * node covers at most `512 * LowFanout` bits. Higher levels use `int64_t` + * excess and count fields. + * + * Low nodes are one-cache-line-aligned arrays of compact per-child summaries: + * @code + * LowNode + * +-----------------------+ + * | cumulative excess | int16_t prefix_excess[32], prefix through child + * +-----------------------+ + * | child relative min | int16_t min_excess[32] + * +-----------------------+ + * | child relative max | int16_t max_excess[32] + * +-----------------------+ + * | count of minima | uint16_t min_count[32] + * +-----------------------+ + * @endcode + * + * High nodes have the same logical fields, but use wider lanes and a fanout + * chosen from @p HighCacheLines: + * @code + * HighNode + * +-----------------------+ + * | cumulative excess | int64_t prefix_excess[kHighFanout], prefix + * | | through child + * +-----------------------+ + * | child relative min | int64_t min_excess[kHighFanout] + * +-----------------------+ + * | child relative max | int64_t max_excess[kHighFanout] + * +-----------------------+ + * | count of minima | uint64_t min_count[kHighFanout] + * +-----------------------+ + * @endcode + * + * Forward and backward excess searches ascend from the starting block until a + * sibling summary can contain the target, then descend through matching child + * ranges. Node scans use SIMD helpers from `bits.h` when available: low nodes + * compare 16 `int16_t` lanes at a time, and high nodes compare four `int64_t` + * lanes at a time. Scalar fallbacks are kept for partial chunks and targets + * that cannot be represented in the low-node lane type. + * + * @tparam HighCacheLines Number of 64-byte cache lines assigned to one high + * summary node. + * @tparam LowFanout Number of child summaries stored in one low-level node. + */ +template +class RmMBTree : public RmMBase> { + public: + static_assert(HighCacheLines > 0); + static_assert(LowFanout > 0); + + static constexpr std::size_t npos = + RmMBase>::npos; + static constexpr std::size_t kBlockBits = 512; + static constexpr std::size_t kBlockWords = kBlockBits / 64; + static constexpr std::size_t kCacheLineBytes = 64; + static constexpr std::size_t kLowFanout = LowFanout; + static constexpr std::size_t kHighFanout = + std::max(2, (512 * HighCacheLines) / (4 * 64)); + static constexpr std::size_t kMaxFanout = std::max(kLowFanout, kHighFanout); + static_assert(kMaxFanout <= 64); + static_assert( + kBlockBits * kLowFanout <= + static_cast(std::numeric_limits::max())); + + RmMBTree() = default; + RmMBTree(const RmMBTree&) = default; + RmMBTree(RmMBTree&&) noexcept = default; + RmMBTree& operator=(const RmMBTree&) = default; + RmMBTree& operator=(RmMBTree&&) noexcept = default; + + /** + * @brief Construct an RmM btree over an external bit-vector span. + * @details The tree stores a non-owning view of @p words and builds its + * rank/select helper plus min-max summaries over the first @p bit_count bits. + * The optional block-size argument is accepted for API compatibility; this + * implementation uses the fixed 512-bit block size. + * @param words External little-endian 64-bit words that contain the bit + * sequence. + * @param bit_count Number of valid bits in @p words. + */ + explicit RmMBTree(std::span words, + std::size_t bit_count, + std::size_t = kBlockBits) { + build(words, bit_count); + } + + std::size_t size_impl() const { return bit_count_; } + + std::size_t rank1_impl(std::size_t end_position) const { + return rank_index_ ? rank_index_->rank(end_position) : 0; + } + + std::size_t rank0_impl(std::size_t end_position) const { + return rank_index_ ? rank_index_->rank0(end_position) : 0; + } + + std::size_t select1_impl(std::size_t rank) const { + if (!rank_index_ || rank == 0) { + return npos; + } + const std::size_t position = rank_index_->select(rank); + return position < bit_count_ ? position : npos; + } + + std::size_t select0_impl(std::size_t rank) const { + if (!rank_index_ || rank == 0) { + return npos; + } + const std::size_t position = rank_index_->select0(rank); + return position < bit_count_ ? position : npos; + } + + std::size_t rank10_impl(std::size_t end_position) const { + if (end_position <= 1 || bit_count_ == 0) { + return 0; + } + end_position = std::min(end_position, bit_count_); + std::size_t count = 0; + for (std::size_t position = 0; position + 1 < end_position; ++position) { + count += bit(position) == 1 && bit(position + 1) == 0; + } + return count; + } + + std::size_t select10_impl(std::size_t rank) const { + if (rank == 0) { + return npos; + } + for (std::size_t position = 0; position + 1 < bit_count_; ++position) { + if (bit(position) == 1 && bit(position + 1) == 0 && --rank == 0) { + return position; + } + } + return npos; + } + + int excess_impl(std::size_t end_position) const { + end_position = std::min(end_position, bit_count_); + return static_cast( + 2 * static_cast(rank1_impl(end_position)) - + static_cast(end_position)); + } + + std::size_t fwdsearch_impl(std::size_t start_position, int delta) const { + if (start_position >= bit_count_) { + return npos; + } + + const std::size_t block_index = start_position / kBlockBits; + const std::size_t block_begin = block_index * kBlockBits; + const std::size_t start_offset = start_position - block_begin; + const std::size_t block_result = + find_fwd_in_block(block_index, start_offset, delta); + if (block_result != npos) { + return block_result; + } + + const std::int64_t relative_target = + block_excess_at_local(block_index, start_offset) + delta; + const std::int64_t block_start_excess = prefix_excess_impl(block_begin); + const std::int64_t target = block_start_excess + relative_target; + std::size_t level = 0; + std::size_t index = block_index; + while (has_parent_level(level)) { + const std::size_t fanout = fanout_to_parent(level); + const std::size_t parent = index / fanout; + const std::size_t sibling_end = + std::min(level_count(level), parent * fanout + fanout); + const NodeScanResult scan = scan_children_fwd( + level, parent, index % fanout + 1, sibling_end - parent * fanout, + target, prefix_excess_impl(node_end_bit(level, index))); + if (scan.found) { + const std::size_t result = + descend_fwd(level, scan.index, target, scan.node_start_excess); + if (result != npos) { + return result; + } + } + level += 1; + index = parent; + } + return npos; + } + + std::size_t bwdsearch_impl(std::size_t start_position, int delta) const { + if (start_position == 0 || start_position > bit_count_) { + return npos; + } + + const std::size_t block_index = (start_position - 1) / kBlockBits; + const std::size_t block_begin = block_index * kBlockBits; + const std::size_t end_offset = start_position - block_begin; + const std::size_t block_result = + find_bwd_in_block(block_index, end_offset, delta); + if (block_result != npos) { + return block_result; + } + + const std::int64_t relative_target = + block_excess_at_local(block_index, end_offset) + delta; + const std::int64_t block_start_excess = prefix_excess_impl(block_begin); + const std::int64_t target = block_start_excess + relative_target; + std::size_t level = 0; + std::size_t index = block_index; + while (has_parent_level(level)) { + const std::size_t fanout = fanout_to_parent(level); + const std::size_t parent = index / fanout; + const NodeScanResult scan = + scan_children_bwd(level, parent, 0, index % fanout, target, + prefix_excess_impl(node_start_bit(level, index))); + if (scan.found) { + if (scan.boundary_only) { + return node_start_bit(level, scan.index); + } + const std::size_t result = + descend_bwd(level, scan.index, target, scan.node_start_excess); + if (result != npos) { + return result; + } + } + level += 1; + index = parent; + } + return npos; + } + + std::size_t range_min_query_pos_impl(std::size_t range_begin, + std::size_t range_end) const { + if (range_begin > range_end || range_end >= bit_count_) { + return npos; + } + return range_extreme_query_pos(range_begin, range_end, true); + } + + int range_min_query_val_impl(std::size_t range_begin, + std::size_t range_end) const { + if (range_begin > range_end || range_end >= bit_count_) { + return 0; + } + return range_extreme_query_val(range_begin, range_end, true); + } + + std::size_t range_max_query_pos_impl(std::size_t range_begin, + std::size_t range_end) const { + if (range_begin > range_end || range_end >= bit_count_) { + return npos; + } + return range_extreme_query_pos(range_begin, range_end, false); + } + + int range_max_query_val_impl(std::size_t range_begin, + std::size_t range_end) const { + if (range_begin > range_end || range_end >= bit_count_) { + return 0; + } + return range_extreme_query_val(range_begin, range_end, false); + } + + std::size_t mincount_impl(std::size_t range_begin, + std::size_t range_end) const { + if (range_begin > range_end || range_end >= bit_count_) { + return 0; + } + return range_min_stats(range_begin, range_end).count; + } + + std::size_t minselect_impl(std::size_t range_begin, + std::size_t range_end, + std::size_t rank) const { + if (range_begin > range_end || range_end >= bit_count_ || rank == 0) { + return npos; + } + const RangeMinStats stats = range_min_stats(range_begin, range_end); + if (rank > stats.count) { + return npos; + } + return range_min_select(range_begin, range_end, stats.value, rank); + } + + std::size_t close_impl(std::size_t open_position) const { + if (open_position >= bit_count_) { + return npos; + } + if (!bit(open_position)) { + return open_position; + } + return fwd_excess_at(open_position, -1); + } + + std::size_t open_impl(std::size_t close_position) const { + if (close_position >= bit_count_) { + return npos; + } + if (bit(close_position)) { + return close_position; + } + return bwdsearch_impl(close_position + 1, 0); + } + + std::size_t enclose_impl(std::size_t position) const { + if (position >= bit_count_) { + return npos; + } + if (!bit(position)) { + return open_impl(position); + } + return bwdsearch_impl(position + 1, -2); + } + + private: + /** + * @brief Search forward using SDSL-style inclusive-position excess semantics. + * @details Public `fwdsearch_impl` starts from a prefix boundary. This helper + * starts at @p position as an included bit position, so the equivalent btree + * boundary search begins at `position + 1`. It is used by + * balanced-parentheses operations such as `close`, where SDSL's + * `fwd_excess(i, -1)` maps to this wrapper. + * @param position Included bit position used as the search origin. + * @param delta Desired excess delta from the prefix after @p position. + * @return Matching bit position, or `npos` when no forward match exists. + */ + std::size_t fwd_excess_at(std::size_t position, int delta) const { + if (position >= bit_count_) { + return npos; + } + if (position + 1 >= bit_count_) { + return npos; + } + return fwdsearch_impl(position + 1, delta); + } + + struct Summary { + std::uint64_t size_bits = 0; + std::uint64_t ones = 0; + std::int64_t block_excess = 0; + std::int64_t min_excess = 0; + std::int64_t max_excess = 0; + std::uint64_t min_count = 0; + }; + + template + struct alignas(kCacheLineBytes) SummaryNode { + using ExcessType = Excess; + using CountType = Count; + static constexpr std::size_t kFanout = Fanout; + std::array prefix_excess{}; + std::array min_excess{}; + std::array max_excess{}; + std::array min_count{}; + }; + + using LowNode = SummaryNode; + using HighNode = SummaryNode; + static_assert(alignof(LowNode) == kCacheLineBytes); + static_assert(alignof(HighNode) == kCacheLineBytes); + static_assert(sizeof(LowNode) % kCacheLineBytes == 0); + static_assert(sizeof(HighNode) % kCacheLineBytes == 0); + + struct ByteAgg { + std::int8_t block_excess = 0; + std::int8_t min_excess = 0; + std::int8_t max_excess = 0; + std::uint8_t min_count = 0; + std::uint8_t pos_first_min = 0; + std::uint8_t pos_first_max = 0; + }; + + static constexpr std::size_t kSearchChunkBits = 128; + static constexpr std::size_t kSearchChunkWords = kSearchChunkBits / 64; + static constexpr std::size_t kSearchChunkCount = + kBlockBits / kSearchChunkBits; + + /** + * @brief Byte-level summaries used by range scans. + * @details Builds a process-local lookup table indexed by byte value. Each + * entry records the byte excess, local min/max excess, minimum multiplicity, + * and first positions of the min/max values. + * @return Immutable 256-entry byte summary table. + */ + static const std::array& byte_lut() { + static const std::array table = [] { + std::array result{}; + for (int byte_value = 0; byte_value < 256; ++byte_value) { + ByteAgg agg; + int current = 0; + int minimum = std::numeric_limits::max(); + int maximum = std::numeric_limits::min(); + const auto bit_at = [&](int bit_index) { + return (byte_value >> bit_index) & 1; + }; + for (int bit_index = 0; bit_index < 8; ++bit_index) { + const int value = bit_at(bit_index); + current += value ? 1 : -1; + if (current < minimum) { + minimum = current; + agg.min_count = 1; + agg.pos_first_min = static_cast(bit_index); + } else if (current == minimum) { + ++agg.min_count; + } + if (current > maximum) { + maximum = current; + agg.pos_first_max = static_cast(bit_index); + } + } + agg.block_excess = static_cast(current); + agg.min_excess = static_cast(minimum); + agg.max_excess = static_cast(maximum); + result[static_cast(byte_value)] = agg; + } + return result; + }(); + return table; + } + + /** + * @brief Build the rank index and min-max tree levels over the input bits. + * @details Validates that @p words contains enough storage, records the + * external span, constructs the internal rank/select helper, summarizes each + * 512-bit block, and builds summary levels above the blocks. + * @param words External little-endian words that remain owned by the caller. + * @param bit_count Number of valid bits in @p words. + * @throws std::invalid_argument if @p words is shorter than required for + * @p bit_count. + */ + void build(std::span words, std::size_t bit_count) { + const std::size_t required_words = (bit_count + 63) / 64; + if (words.size() < required_words) { + throw std::invalid_argument( + "RmMBTree input span is shorter than bit_count"); + } + + bits_ = words; + bit_count_ = bit_count; + rank_index_.emplace(words, bit_count); + block_count_ = (bit_count_ + kBlockBits - 1) / kBlockBits; + std::vector block_summaries(block_count_); + + for (std::size_t block_index = 0; block_index < block_count_; + ++block_index) { + const std::size_t block_begin = block_index * kBlockBits; + const std::size_t block_end = + std::min(bit_count_, block_begin + kBlockBits); + block_summaries[block_index] = + summarize_bits(block_begin, block_end - block_begin); + } + + build_levels(block_summaries); + } + + /** + * @brief Build low and high summary levels from 512-bit block summaries. + * @details Level 0 is the block-summary level. The first parent level uses + * compact 16-bit low nodes, and all higher levels use 64-bit high nodes until + * a single top summary remains. + * @param block_summaries Relative summaries for each 512-bit block. + */ + void build_levels(const std::vector& block_summaries) { + level_counts_.clear(); + low_levels_.clear(); + high_levels_.clear(); + top_summary_ = Summary{}; + level_counts_.push_back(block_summaries.size()); + if (block_summaries.empty()) { + return; + } + + std::vector current = block_summaries; + current = build_parent_level(current, low_levels_.emplace_back()); + level_counts_.push_back(current.size()); + current = + build_parent_level(current, high_levels_.emplace_back()); + level_counts_.push_back(current.size()); + while (current.size() > 1) { + current = + build_parent_level(current, high_levels_.emplace_back()); + level_counts_.push_back(current.size()); + } + top_summary_ = current.front(); + } + + /** + * @brief Build one parent level and return summaries for the produced nodes. + * @details Groups @p in by the fanout of @p Node, stores each child summary + * in the parent node, and returns relative summaries for the newly built + * parent level. + * @tparam Node LowNode or HighNode. + * @param in Child summaries from the previous level. + * @param nodes Destination storage for the constructed parent nodes. + * @return Relative summaries for @p nodes. + */ + template + static std::vector build_parent_level(const std::vector& in, + std::vector& nodes) { + constexpr std::size_t fanout = Node::kFanout; + std::vector out((in.size() + fanout - 1) / fanout); + nodes.resize(out.size()); + for (std::size_t parent = 0; parent < out.size(); ++parent) { + const std::size_t begin = parent * fanout; + const std::size_t end = std::min(in.size(), begin + fanout); + Summary combined; + for (std::size_t i = begin; i < end; ++i) { + store_child_summary(nodes[parent], i - begin, combined.block_excess, + in[i]); + combined = append(combined, in[i]); + } + out[parent] = combined; + } + return out; + } + + /** + * @brief Store one child summary in a node using inclusive prefix excess. + * @details Stores cumulative excess through @p slot in + * `prefix_excess[slot]`, along with the child min/max excess and minimum + * count. The child total excess is later reconstructed from adjacent + * cumulative prefixes. + * @tparam Node LowNode or HighNode. + * @param node Parent node being filled. + * @param slot Child slot inside @p node. + * @param prefix_excess Excess before this child within @p node. + * @param summary Relative summary of the child subtree. + */ + template + static void store_child_summary(Node& node, + std::size_t slot, + std::int64_t prefix_excess, + const Summary& summary) { + node.prefix_excess[slot] = + static_cast( + prefix_excess + summary.block_excess); + node.min_excess[slot] = + static_cast( + summary.min_excess); + node.max_excess[slot] = + static_cast( + summary.max_excess); + node.min_count[slot] = + static_cast( + summary.min_count); + } + + /** + * @brief Summarize a contiguous bit range relative to its beginning. + * @details Scans the range bit by bit and computes total excess, min/max + * relative excess, number of positions attaining the minimum, and number of 1 + * bits. + * @param begin First bit position. + * @param length Number of bits to summarize. + * @return Relative summary of `[begin, begin + length)`. + */ + Summary summarize_bits(std::size_t begin, std::size_t length) const { + Summary summary; + summary.size_bits = length; + if (length == 0) { + return summary; + } + int current = 0; + int minimum = std::numeric_limits::max(); + int maximum = std::numeric_limits::min(); + for (std::size_t offset = 0; offset < length; ++offset) { + const std::uint8_t value = bit(begin + offset); + summary.ones += value; + current += value ? 1 : -1; + if (current < minimum) { + minimum = current; + summary.min_count = 1; + } else if (current == minimum) { + ++summary.min_count; + } + if (current > maximum) { + maximum = current; + } + } + summary.block_excess = current; + summary.min_excess = minimum; + summary.max_excess = maximum; + return summary; + } + + /** + * @brief Concatenate two relative summaries. + * @details Produces the summary for `left || right`, translating right-side + * extrema by the total excess of @p left and merging minimum counts when both + * sides attain the combined minimum. + * @param left Summary of the first range. + * @param right Summary of the following range. + * @return Relative summary of the concatenation. + */ + static Summary append(Summary left, const Summary& right) { + if (left.size_bits == 0) { + return right; + } + if (right.size_bits == 0) { + return left; + } + Summary result; + result.size_bits = left.size_bits + right.size_bits; + result.ones = left.ones + right.ones; + result.block_excess = left.block_excess + right.block_excess; + result.min_excess = + std::min(left.min_excess, left.block_excess + right.min_excess); + result.max_excess = + std::max(left.max_excess, left.block_excess + right.max_excess); + result.min_count = 0; + if (left.min_excess == result.min_excess) { + result.min_count += left.min_count; + } + if (left.block_excess + right.min_excess == result.min_excess) { + result.min_count += right.min_count; + } + return result; + } + + /** + * @brief Number of valid bits in a 512-bit block. + * @details All blocks except the last have `kBlockBits` bits. The last block + * may be partial when `size()` is not a multiple of 512. + * @param block_index Zero-based block index. + * @return Number of valid bits in the block. + */ + std::size_t block_size(std::size_t block_index) const { + const std::size_t begin = block_index * kBlockBits; + return std::min(bit_count_ - begin, kBlockBits); + } + + /** + * @brief Whether a block is full and backed by all eight storage words. + * @details Fast SIMD/chunk search paths require a full 512-bit block and all + * eight words to be present in the external span. + * @param block_index Zero-based block index. + * @return True when the block can be processed as exactly eight words. + */ + bool full_block_has_words(std::size_t block_index) const { + return block_size(block_index) == kBlockBits && + (block_index + 1) * kBlockWords <= bits_.size(); + } + + /** + * @brief Prefix excess at a block-local boundary using direct popcounts. + * @details Reads only the words in the target block instead of using the rank + * index. This is used on local-search miss paths before ascending the tree. + * @param block_index Zero-based block index. + * @param offset Exclusive block-local prefix boundary, clamped to the block + * length. + * @return Prefix excess relative to the beginning of the block. + */ + std::int64_t block_excess_at_local(std::size_t block_index, + std::size_t offset) const { + const std::size_t length = block_size(block_index); + offset = std::min(offset, length); + if (offset == 0) { + return 0; + } + + const std::size_t first_word = block_index * kBlockWords; + std::size_t remaining = offset; + std::int64_t ones = 0; + std::size_t word_offset = 0; + while (remaining >= 64) { + ones += std::popcount(bits_[first_word + word_offset]); + remaining -= 64; + ++word_offset; + } + if (remaining != 0) { + ones += std::popcount(bits_[first_word + word_offset] & + first_bits_mask(remaining)); + } + return 2 * ones - static_cast(offset); + } + + /** + * @brief Total excess of a 128-bit search chunk. + * @details Computes `2 * popcount(chunk) - 128` for the two input words. + * @param chunk Pointer to two little-endian words. + * @return Total excess of the 128-bit chunk. + */ + static int chunk_excess_128(const std::uint64_t* chunk) { + return 2 * static_cast(std::popcount(chunk[0]) + + std::popcount(chunk[1])) - + static_cast(kSearchChunkBits); + } + + /** + * @brief Find a forward excess target inside one block from a local boundary. + * @details Searches for the first bit position `p >= start_offset` in the + * block such that the prefix excess at `p + 1` equals the excess at + * `start_offset` plus @p delta. Full blocks use 128-bit chunk primitives; + * partial blocks fall back to direct scalar scanning. + * @param block_index Zero-based block index. + * @param start_offset Block-local start boundary. + * @param delta Desired excess delta from `start_offset`. + * @return Global bit position of the match, or `npos`. + */ + std::size_t find_fwd_in_block(std::size_t block_index, + std::size_t start_offset, + std::int64_t delta) const { + const std::size_t length = block_size(block_index); + if (start_offset >= length) { + return npos; + } + + if (full_block_has_words(block_index)) { + const std::uint64_t* block = bits_.data() + block_index * kBlockWords; + const std::size_t first_chunk = start_offset / kSearchChunkBits; + std::int64_t target = delta; + for (std::size_t chunk = first_chunk; chunk < kSearchChunkCount; + ++chunk) { + const std::uint64_t* chunk_words = block + chunk * kSearchChunkWords; + const std::size_t local_start = + chunk == first_chunk ? start_offset - chunk * kSearchChunkBits : 0; + if (chunk == first_chunk) { + target += prefix_excess_128(chunk_words, local_start); + } + int block_excess = 0; + const std::size_t offset = forward_search_128( + chunk_words, static_cast(target), local_start, &block_excess); + if (offset != kSearchChunkBits) { + return block_index * kBlockBits + chunk * kSearchChunkBits + offset; + } + target -= block_excess; + } + return npos; + } + + std::int64_t current = block_excess_at_local(block_index, start_offset); + const std::int64_t relative_target = current + delta; + const std::size_t block_begin = block_index * kBlockBits; + for (std::size_t offset = start_offset; offset < length; ++offset) { + current += bit(block_begin + offset) ? 1 : -1; + if (current == relative_target) { + return block_index * kBlockBits + offset; + } + } + return npos; + } + + /** + * @brief Find a backward excess target inside one block from a local + * boundary. + * @details Searches for the greatest prefix boundary `p < end_offset` in the + * block such that the prefix excess at `p` equals the excess at @p end_offset + * plus @p delta. The returned value is a global prefix-boundary/bit position. + * @param block_index Zero-based block index. + * @param end_offset Exclusive block-local boundary to search before. + * @param delta Desired excess delta from `end_offset`. + * @return Global position of the matching boundary, or `npos`. + */ + std::size_t find_bwd_in_block(std::size_t block_index, + std::size_t end_offset, + std::int64_t delta) const { + if (end_offset == 0) { + return npos; + } + const std::size_t max_prefix_length = end_offset - 1; + + if (full_block_has_words(block_index)) { + const std::uint64_t* block = bits_.data() + block_index * kBlockWords; + const std::size_t last_chunk = max_prefix_length / kSearchChunkBits; + std::int64_t target = delta; + for (std::size_t chunk = last_chunk + 1; chunk > 0;) { + --chunk; + const std::uint64_t* chunk_words = block + chunk * kSearchChunkWords; + const std::size_t local_end = + chunk == last_chunk ? end_offset - chunk * kSearchChunkBits + : kSearchChunkBits; + if (chunk == last_chunk) { + target += prefix_excess_128(chunk_words, local_end); + } + int block_excess = 0; + const std::size_t offset = backward_search_128( + chunk_words, static_cast(target), local_end, &block_excess); + if (offset != kSearchChunkBits) { + return block_index * kBlockBits + chunk * kSearchChunkBits + offset; + } + if (chunk > 0) { + target += chunk_excess_128(block + (chunk - 1) * kSearchChunkWords); + } + } + return npos; + } + + const std::int64_t relative_target = + block_excess_at_local(block_index, end_offset) + delta; + std::int64_t current = + block_excess_at_local(block_index, max_prefix_length); + const std::size_t block_begin = block_index * kBlockBits; + for (std::size_t prefix_length = max_prefix_length; prefix_length > 0; + --prefix_length) { + if (current == relative_target) { + return block_index * kBlockBits + prefix_length; + } + current -= bit(block_begin + prefix_length - 1) ? 1 : -1; + } + return relative_target == 0 ? block_index * kBlockBits : npos; + } + + /** + * @brief Descend from a matching summary node to the first forward match. + * @details Starting from a node whose min/max range contains @p target, + * repeatedly chooses the leftmost matching child and finally scans the leaf + * block. + * @param level Summary level of the starting node. + * @param index Node index at @p level. + * @param target Absolute prefix excess target. + * @param node_start_excess Absolute excess before the starting node. + * @return Global bit position of the forward match, or `npos`. + */ + std::size_t descend_fwd(std::size_t level, + std::size_t index, + std::int64_t target, + std::int64_t node_start_excess) const { + while (level > 0) { + const std::size_t child_level = level - 1; + const std::size_t fanout = fanout_to_parent(child_level); + const std::size_t child_begin = index * fanout; + const std::size_t child_end = + std::min(level_count(child_level), child_begin + fanout); + const NodeScanResult scan = + scan_children_fwd(child_level, index, 0, child_end - child_begin, + target, node_start_excess); + if (!scan.found) { + return npos; + } + index = scan.index; + level = child_level; + node_start_excess = scan.node_start_excess; + } + return find_fwd_in_block(index, 0, target - node_start_excess); + } + + /** + * @brief Descend from a matching summary node to the last backward match. + * @details Starting from a node whose min/max range can contain @p target, + * repeatedly chooses the rightmost matching child and finally scans the leaf + * block backward. Boundary-only matches return the child start directly. + * @param level Summary level of the starting node. + * @param index Node index at @p level. + * @param target Absolute prefix excess target. + * @param node_start_excess Absolute excess before the starting node. + * @return Global position of the backward match, or `npos`. + */ + std::size_t descend_bwd(std::size_t level, + std::size_t index, + std::int64_t target, + std::int64_t node_start_excess) const { + while (level > 0) { + const std::size_t child_level = level - 1; + const std::size_t fanout = fanout_to_parent(child_level); + const std::size_t child_begin = index * fanout; + const std::size_t child_end = + std::min(level_count(child_level), child_begin + fanout); + std::int64_t child_start_excess = + node_start_excess + summary_at(level, index).block_excess; + const NodeScanResult scan = + scan_children_bwd(child_level, index, 0, child_end - child_begin, + target, child_start_excess); + if (!scan.found) { + return npos; + } + if (scan.boundary_only) { + return node_start_bit(child_level, scan.index); + } + index = scan.index; + level = child_level; + node_start_excess = scan.node_start_excess; + } + const std::int64_t block_excess = summary_at(0, index).block_excess; + return find_bwd_in_block(index, block_size(index), + target - node_start_excess - block_excess); + } + + struct NodeScanResult { + bool found = false; + bool boundary_only = false; + std::size_t index = 0; + std::int64_t node_start_excess = 0; + }; + + /** + * @brief Scan child summaries in increasing order for a forward match. + * @details Dispatches to either low-node or high-node scanning depending on + * @p child_level. The scan interval is `[begin_slot, end_slot)`. + * @param child_level Level of the children being scanned. + * @param parent Parent node index. + * @param begin_slot First child slot to inspect. + * @param end_slot One-past-last child slot to inspect. + * @param target Absolute prefix excess target. + * @param begin_excess Absolute excess at `begin_slot`. + * @return Scan result describing the first matching child, if any. + */ + NodeScanResult scan_children_fwd(std::size_t child_level, + std::size_t parent, + std::size_t begin_slot, + std::size_t end_slot, + std::int64_t target, + std::int64_t begin_excess) const { + if (begin_slot >= end_slot) { + return {}; + } + if (child_level == 0) { + return scan_node_fwd(low_levels_[0][parent], child_level, parent, + begin_slot, end_slot, target, begin_excess); + } + return scan_node_fwd(high_levels_[child_level - 1][parent], child_level, + parent, begin_slot, end_slot, target, begin_excess); + } + + /** + * @brief Scan child summaries in decreasing order for a backward match. + * @details Dispatches to either low-node or high-node scanning depending on + * @p child_level. The scan interval is `[begin_slot, end_slot)`, inspected + * from right to left. + * @param child_level Level of the children being scanned. + * @param parent Parent node index. + * @param begin_slot First child slot in the scan interval. + * @param end_slot One-past-last child slot in the scan interval. + * @param target Absolute prefix excess target. + * @param end_excess Absolute excess at `end_slot`. + * @return Scan result describing the last matching child, if any. + */ + NodeScanResult scan_children_bwd(std::size_t child_level, + std::size_t parent, + std::size_t begin_slot, + std::size_t end_slot, + std::int64_t target, + std::int64_t end_excess) const { + if (begin_slot >= end_slot) { + return {}; + } + if (child_level == 0) { + return scan_node_bwd(low_levels_[0][parent], child_level, parent, + begin_slot, end_slot, target, end_excess); + } + return scan_node_bwd(high_levels_[child_level - 1][parent], child_level, + parent, begin_slot, end_slot, target, end_excess); + } + + /** + * @brief Scan one concrete node type in increasing slot order. + * @details Converts the absolute target into each child's relative target and + * checks the child min/max range in SIMD-sized groups. + * @tparam Node LowNode or HighNode. + * @param node Parent node being scanned. + * @param child_level Level of @p node's children. + * @param parent Parent node index. + * @param begin_slot First child slot to inspect. + * @param end_slot One-past-last child slot to inspect. + * @param target Absolute prefix excess target. + * @param begin_excess Absolute excess at `begin_slot`. + * @return First matching child, if any. + */ + template + NodeScanResult scan_node_fwd(const Node& node, + std::size_t child_level, + std::size_t parent, + std::size_t begin_slot, + std::size_t end_slot, + std::int64_t target, + std::int64_t begin_excess) const { + const std::int64_t node_base_excess = + begin_excess - prefix_excess_at(node, begin_slot); + for (std::size_t slot = begin_slot; slot < end_slot;) { + const std::size_t lane_count = + std::min(vector_lane_count(), end_slot - slot); + const std::uint32_t mask = matching_chunk_mask( + node, slot, lane_count, target - node_base_excess, false); + if (mask != 0) { + const std::size_t lane = std::countr_zero(mask); + const std::size_t matched_slot = slot + lane; + return {true, false, + parent * fanout_to_parent(child_level) + matched_slot, + child_start_excess(node, node_base_excess, matched_slot)}; + } + slot += lane_count; + } + return {}; + } + + /** + * @brief Scan one concrete node type in decreasing slot order. + * @details Converts the absolute target into each child's relative target and + * checks the child min/max range in SIMD-sized groups from right to left. + * @tparam Node LowNode or HighNode. + * @param node Parent node being scanned. + * @param child_level Level of @p node's children. + * @param parent Parent node index. + * @param begin_slot First child slot in the scan interval. + * @param end_slot One-past-last child slot in the scan interval. + * @param target Absolute prefix excess target. + * @param end_excess Absolute excess at `end_slot`. + * @return Last matching child, if any. + */ + template + NodeScanResult scan_node_bwd(const Node& node, + std::size_t child_level, + std::size_t parent, + std::size_t begin_slot, + std::size_t end_slot, + std::int64_t target, + std::int64_t end_excess) const { + const std::int64_t node_base_excess = + end_excess - prefix_excess_at(node, end_slot); + for (std::size_t slot_end = end_slot; slot_end > begin_slot;) { + const std::size_t lane_count = + std::min(vector_lane_count(), slot_end - begin_slot); + const std::size_t slot = slot_end - lane_count; + const std::uint32_t mask = matching_chunk_mask( + node, slot, lane_count, target - node_base_excess, true); + if (mask != 0) { + const std::size_t lane = + static_cast(std::bit_width(mask) - 1); + const std::size_t matched_slot = slot + lane; + const std::int64_t relative_target = + target - child_start_excess(node, node_base_excess, matched_slot); + const bool interior_match = + node.min_excess[matched_slot] <= relative_target && + relative_target <= node.max_excess[matched_slot]; + return {true, !interior_match, + parent * fanout_to_parent(child_level) + matched_slot, + child_start_excess(node, node_base_excess, matched_slot)}; + } + slot_end = slot; + } + return {}; + } + + /** + * @brief Inclusive prefix excess before child slot. + * @details Node storage keeps inclusive cumulative excess through each child. + * This helper converts that representation to the exclusive prefix before a + * slot. + * @tparam Node LowNode or HighNode. + * @param node Node containing cumulative prefixes. + * @param slot Child slot boundary in `[0, Node::kFanout]`. + * @return Excess before @p slot within @p node. + */ + template + static std::int64_t prefix_excess_at(const Node& node, std::size_t slot) { + if (slot == 0) { + return 0; + } + return prefix_through(node, slot - 1); + } + + /** + * @brief Absolute excess at the start of a child slot. + * @details Adds the node's absolute base excess to the child-local exclusive + * prefix before @p slot. + * @tparam Node LowNode or HighNode. + * @param node Node containing cumulative prefixes. + * @param node_base_excess Absolute excess before child slot 0. + * @param slot Child slot. + * @return Absolute excess before the child at @p slot. + */ + template + static std::int64_t child_start_excess(const Node& node, + std::int64_t node_base_excess, + std::size_t slot) { + return node_base_excess + prefix_excess_at(node, slot); + } + + /** + * @brief Inclusive prefix excess through child slot. + * @details Reads the stored cumulative prefix value for a child slot. + * @tparam Node LowNode or HighNode. + * @param node Node containing cumulative prefixes. + * @param slot Child slot to read. + * @return Excess through @p slot within @p node. + */ + template + static std::int64_t prefix_through(const Node& node, std::size_t slot) { + return static_cast(node.prefix_excess[slot]); + } + + /** + * @brief Total excess of one child reconstructed from inclusive prefixes. + * @details Computes the difference between the inclusive prefix through + * @p slot and the exclusive prefix before @p slot. + * @tparam Node LowNode or HighNode. + * @param node Node containing cumulative prefixes. + * @param slot Child slot. + * @return Total excess of the child subtree at @p slot. + */ + template + static std::int64_t child_excess(const Node& node, std::size_t slot) { + return prefix_through(node, slot) - prefix_excess_at(node, slot); + } + + /** + * @brief SIMD lane count used for one node scan chunk. + * @details Low nodes use 16 int16 lanes, while high nodes use 4 int64 lanes. + * @tparam Node LowNode or HighNode. + * @return Number of child slots processed per vector chunk. + */ + template + static constexpr std::size_t vector_lane_count() { + if constexpr (std::is_same_v) { + return 16; + } else { + return 4; + } + } + + /** + * @brief Return a bit mask of child lanes whose min/max range can match. + * @details Uses AVX2 specializations when the lane count exactly matches the + * node type; otherwise falls back to scalar range checks. + * @tparam Node LowNode or HighNode. + * @param node Node whose child ranges are checked. + * @param slot First child slot represented by lane 0. + * @param lane_count Number of valid lanes to check. + * @param target_in_node Target excess relative to the start of @p node. + * @param include_zero_boundary Whether a relative target of zero is accepted + * as a boundary-only match for backward search. + * @return Bit mask with bit `i` set when lane `i` can match. + */ + template + static std::uint32_t matching_chunk_mask(const Node& node, + std::size_t slot, + std::size_t lane_count, + std::int64_t target_in_node, + bool include_zero_boundary) { +#ifdef PIXIE_AVX2_SUPPORT + if constexpr (std::is_same_v) { + if (lane_count == 16 && + target_in_node >= std::numeric_limits::min() && + target_in_node <= std::numeric_limits::max()) { + alignas(32) std::int16_t prefix_before[16]{}; + fill_prefix_before(node, slot, prefix_before); + return rmm_btree_match_mask_i16x16( + prefix_before, node.min_excess.data() + slot, + node.max_excess.data() + slot, + static_cast(target_in_node), include_zero_boundary); + } + } else if constexpr (std::is_same_v) { + if (lane_count == 4) { + alignas(32) std::int64_t prefix_before[4]{}; + fill_prefix_before(node, slot, prefix_before); + return rmm_btree_match_mask_i64x4( + prefix_before, node.min_excess.data() + slot, + node.max_excess.data() + slot, target_in_node, + include_zero_boundary); + } + } +#endif + std::uint32_t result = 0; + for (std::size_t lane = 0; lane < lane_count; ++lane) { + const std::int64_t rel = + target_in_node - prefix_excess_at(node, slot + lane); + const bool found = (include_zero_boundary && rel == 0) || + (node.min_excess[slot + lane] <= rel && + rel <= node.max_excess[slot + lane]); + if (found) { + result |= std::uint32_t{1} << lane; + } + } + return result; + } + + /** + * @brief Fill SIMD lanes with child-start prefix excess values. + * @details Summary nodes store inclusive prefix excess through each child. + * This helper converts that representation into exclusive prefix-before-child + * values for the vector chunk beginning at @p slot. Lane zero is zero when + * the chunk starts at the first child; otherwise each lane reads the previous + * stored inclusive prefix. + * @tparam Node LowNode or HighNode. + * @param node Summary node containing inclusive child prefixes. + * @param slot First child slot represented by output lane zero. + * @param out Destination array with `vector_lane_count()` entries. + */ + template + static void fill_prefix_before(const Node& node, + std::size_t slot, + typename Node::ExcessType* out) { + if (slot == 0) { + out[0] = 0; + for (std::size_t lane = 1; lane < vector_lane_count(); ++lane) { + out[lane] = node.prefix_excess[lane - 1]; + } + return; + } + for (std::size_t lane = 0; lane < vector_lane_count(); ++lane) { + out[lane] = node.prefix_excess[slot + lane - 1]; + } + } + + /** + * @brief Whether a summary can contain a forward target. + * @details Forward search needs an interior prefix whose relative excess lies + * in the child's `[min_excess, max_excess]` interval. + * @param summary Child summary to test. + * @param relative_target Target excess relative to the child start. + * @return True if the target may occur in the child. + */ + static bool contains_fwd(const Summary& summary, + std::int64_t relative_target) { + return summary.min_excess <= relative_target && + relative_target <= summary.max_excess; + } + + /** + * @brief Whether a summary can contain a backward target or left boundary. + * @details Backward search accepts a relative target of zero as a match at + * the child boundary; otherwise it uses the same min/max containment as + * forward search. + * @param summary Child summary to test. + * @param relative_target Target excess relative to the child start. + * @return True if the target may occur in or at the child boundary. + */ + static bool contains_bwd(const Summary& summary, + std::int64_t relative_target) { + return relative_target == 0 || contains_fwd(summary, relative_target); + } + + std::int64_t prefix_excess_impl(std::size_t end_position) const { + return 2 * static_cast(rank1_impl(end_position)) - + static_cast(end_position); + } + + /** + * @brief Whether a level has a non-empty parent level. + * @details Used while ascending the btree from a block toward the root. + * @param level Current summary level. + * @return True when `level + 1` exists and has at least one node. + */ + bool has_parent_level(std::size_t level) const { + return level + 1 < total_levels() && level_count(level + 1) != 0; + } + + /** + * @brief Number of summary levels including the block level. + * @details Level 0 corresponds to 512-bit blocks; higher levels are stored in + * low/high summary-node arrays. + * @return Number of levels currently represented. + */ + std::size_t total_levels() const { return level_counts_.size(); } + + /** + * @brief Number of nodes at a summary level. + * @details Returns zero for out-of-range levels. + * @param level Summary level to query. + * @return Node count at @p level. + */ + std::size_t level_count(std::size_t level) const { + return level < level_counts_.size() ? level_counts_[level] : 0; + } + + /** + * @brief Reconstruct the relative summary for one node or block. + * @details Retrieves a child summary from its parent node. For the top level, + * returns the cached top summary. + * @param level Summary level of the requested node. + * @param index Node index within @p level. + * @return Relative summary of the requested subtree. + */ + Summary summary_at(std::size_t level, std::size_t index) const { + if (level + 1 >= total_levels()) { + return top_summary_; + } + const std::size_t parent_level = level + 1; + const std::size_t fanout = fanout_to_parent(level); + const std::size_t parent = index / fanout; + const std::size_t slot = index % fanout; + Summary summary; + if (parent_level == 1) { + const LowNode& node = low_levels_[0][parent]; + summary.block_excess = child_excess(node, slot); + summary.min_excess = node.min_excess[slot]; + summary.max_excess = node.max_excess[slot]; + summary.min_count = node.min_count[slot]; + } else { + const HighNode& node = high_levels_[parent_level - 2][parent]; + summary.block_excess = child_excess(node, slot); + summary.min_excess = node.min_excess[slot]; + summary.max_excess = node.max_excess[slot]; + summary.min_count = node.min_count[slot]; + } + return summary; + } + + /** + * @brief Fanout from a level to its parent level. + * @details Blocks are grouped by `kLowFanout` into low nodes; all higher + * levels use `kHighFanout`. + * @param level Child level. + * @return Fanout used to compute the parent of @p level. + */ + static std::size_t fanout_to_parent(std::size_t level) { + return level == 0 ? kLowFanout : kHighFanout; + } + + /** + * @brief Multiply sizes, saturating on overflow. + * @details Returns `max_size_t` instead of overflowing when the product + * cannot be represented. + * @param lhs Left operand. + * @param rhs Right operand. + * @return Saturating product. + */ + static std::size_t mul_clamped(std::size_t lhs, std::size_t rhs) { + if (lhs != 0 && rhs > std::numeric_limits::max() / lhs) { + return std::numeric_limits::max(); + } + return lhs * rhs; + } + + /** + * @brief Maximum number of bits covered by one node at a level. + * @details Level 0 spans one 512-bit block. Higher levels multiply by the low + * fanout once and then by the high fanout for each additional level. + * @param level Summary level. + * @return Maximum bit span of one node at @p level, saturated on overflow. + */ + static std::size_t level_span_bits(std::size_t level) { + std::size_t span = kBlockBits; + if (level >= 1) { + span = mul_clamped(span, kLowFanout); + } + if (level >= 2) { + for (std::size_t i = 2; i <= level; ++i) { + span = mul_clamped(span, kHighFanout); + } + } + return span; + } + + /** + * @brief First bit position covered by a node. + * @details Computes `index * level_span_bits(level)` and clamps the result to + * the sequence size on overflow or beyond-end nodes. + * @param level Summary level. + * @param index Node index within @p level. + * @return First covered bit position, clamped to `bit_count_`. + */ + std::size_t node_start_bit(std::size_t level, std::size_t index) const { + const std::size_t span = level_span_bits(level); + if (span != 0 && index > std::numeric_limits::max() / span) { + return bit_count_; + } + return std::min(bit_count_, index * span); + } + + /** + * @brief Number of valid bits covered by a node. + * @details Returns the full level span for interior nodes and the remaining + * sequence length for the last partial node. + * @param level Summary level. + * @param index Node index within @p level. + * @return Number of valid bits covered by the node. + */ + std::size_t node_size_bits(std::size_t level, std::size_t index) const { + const std::size_t start = node_start_bit(level, index); + if (start >= bit_count_) { + return 0; + } + return std::min(level_span_bits(level), bit_count_ - start); + } + + /** + * @brief Exclusive end bit position covered by a node. + * @details Adds `node_size_bits(level, index)` to `node_start_bit(level, + * index)`, clamping to the sequence size on overflow. + * @param level Summary level. + * @param index Node index within @p level. + * @return Exclusive end position of the covered bit interval. + */ + std::size_t node_end_bit(std::size_t level, std::size_t index) const { + const std::size_t start = node_start_bit(level, index); + const std::uint64_t size = node_size_bits(level, index); + if (size > std::numeric_limits::max() - start) { + return bit_count_; + } + return std::min(bit_count_, start + static_cast(size)); + } + + struct NodeRef { + std::size_t level = 0; + std::size_t index = 0; + }; + + static constexpr std::size_t kMaxCoverNodes = 512; + + struct ScanResult { + std::int64_t block_excess = 0; + std::int64_t min_value = std::numeric_limits::max(); + std::int64_t max_value = std::numeric_limits::min(); + std::uint64_t min_count = 0; + std::size_t min_position = npos; + std::size_t max_position = npos; + }; + + struct RangeExtremeResult { + std::size_t position = npos; + std::int64_t value = 0; + }; + + struct RangeMinStats { + std::int64_t value = 0; + std::uint64_t count = 0; + }; + + struct Cover { + std::array nodes{}; + std::size_t size = 0; + + /** + * @brief Append a cover node if the fixed-capacity buffer has room. + * @details Cover construction has a conservative fixed upper bound. Extra + * pushes are ignored instead of growing storage. + * @param node Node reference to append. + */ + void push(NodeRef node) { + if (size < nodes.size()) { + nodes[size++] = node; + } + } + }; + + /** + * @brief Return the position of the first range minimum or maximum. + * @details Wrapper around `range_extreme_query` that extracts only the + * position. + * @param range_begin Inclusive range start. + * @param range_end Inclusive range end. + * @param find_min True for minimum, false for maximum. + * @return Position of the first extreme, or `npos`. + */ + std::size_t range_extreme_query_pos(std::size_t range_begin, + std::size_t range_end, + bool find_min) const { + return range_extreme_query(range_begin, range_end, find_min).position; + } + + /** + * @brief Return the value of the range minimum or maximum. + * @details Wrapper around `range_extreme_query` that extracts only the + * relative excess value. + * @param range_begin Inclusive range start. + * @param range_end Inclusive range end. + * @param find_min True for minimum, false for maximum. + * @return Extreme relative excess value. + */ + int range_extreme_query_val(std::size_t range_begin, + std::size_t range_end, + bool find_min) const { + return static_cast( + range_extreme_query(range_begin, range_end, find_min).value); + } + + /** + * @brief Compute both position and value for a range minimum or maximum. + * @details Scans partial edge blocks directly, covers the aligned middle with + * summary nodes, and descends only if the best extreme is represented by a + * summary node. + * @param range_begin Inclusive range start. + * @param range_end Inclusive range end. + * @param find_min True for minimum, false for maximum. + * @return Extreme position and value. + */ + RangeExtremeResult range_extreme_query(std::size_t range_begin, + std::size_t range_end, + bool find_min) const { + std::int64_t value = 0; + std::int64_t best = find_min ? std::numeric_limits::max() + : std::numeric_limits::min(); + std::size_t best_position = npos; + NodeRef best_node; + std::int64_t prefix_at_best_node = 0; + bool best_is_node = false; + + auto consider_point = [&](std::int64_t candidate, std::size_t position) { + if ((find_min && candidate < best) || (!find_min && candidate > best)) { + best = candidate; + best_position = position; + best_is_node = false; + } + }; + + const std::size_t range_end_exclusive = range_end + 1; + const std::size_t first_full_block = + (range_begin + kBlockBits - 1) / kBlockBits; + const std::size_t full_begin = + std::min(range_end_exclusive, first_full_block * kBlockBits); + if (range_begin < full_begin) { + const ScanResult scan = scan_range(range_begin, full_begin); + consider_point(find_min ? scan.min_value : scan.max_value, + find_min ? scan.min_position : scan.max_position); + value += scan.block_excess; + } + + const std::size_t last_full_block_exclusive = + range_end_exclusive / kBlockBits; + const std::size_t middle_begin = full_begin; + const std::size_t middle_end = + std::max(middle_begin, last_full_block_exclusive * kBlockBits); + if (middle_begin < middle_end) { + Cover cover; + collect_cover(middle_begin, middle_end, cover); + for (std::size_t i = 0; i < cover.size; ++i) { + const NodeRef& node = cover.nodes[i]; + Summary summary = summary_at(node.level, node.index); + const std::int64_t candidate = + value + (find_min ? summary.min_excess : summary.max_excess); + if ((find_min && candidate < best) || (!find_min && candidate > best)) { + best = candidate; + best_node = node; + prefix_at_best_node = value; + best_is_node = true; + } + value += summary.block_excess; + } + } + + if (middle_end < range_end_exclusive) { + const ScanResult scan = scan_range(middle_end, range_end_exclusive); + const std::int64_t candidate = + value + (find_min ? scan.min_value : scan.max_value); + consider_point(candidate, + find_min ? scan.min_position : scan.max_position); + } + + if (best_is_node) { + best_position = + descend_first_extreme(best_node.level, best_node.index, + best - prefix_at_best_node, find_min); + } + return {best_position, + best == std::numeric_limits::max() || + best == std::numeric_limits::min() + ? 0 + : best}; + } + + /** + * @brief Compute minimum value and multiplicity in an inclusive range. + * @details Uses the same edge-block plus aligned-cover decomposition as range + * extrema, accumulating the number of positions attaining the best minimum. + * @param range_begin Inclusive range start. + * @param range_end Inclusive range end. + * @return Minimum relative excess and its multiplicity. + */ + RangeMinStats range_min_stats(std::size_t range_begin, + std::size_t range_end) const { + std::int64_t value = 0; + std::int64_t best = std::numeric_limits::max(); + std::uint64_t count = 0; + + auto consider = [&](std::int64_t candidate, std::uint64_t candidate_count) { + if (candidate < best) { + best = candidate; + count = candidate_count; + } else if (candidate == best) { + count += candidate_count; + } + }; + + const std::size_t range_end_exclusive = range_end + 1; + const std::size_t first_full_block = + (range_begin + kBlockBits - 1) / kBlockBits; + const std::size_t full_begin = + std::min(range_end_exclusive, first_full_block * kBlockBits); + if (range_begin < full_begin) { + const ScanResult scan = scan_range(range_begin, full_begin); + consider(scan.min_value, scan.min_count); + value += scan.block_excess; + } + + const std::size_t last_full_block_exclusive = + range_end_exclusive / kBlockBits; + const std::size_t middle_begin = full_begin; + const std::size_t middle_end = + std::max(middle_begin, last_full_block_exclusive * kBlockBits); + if (middle_begin < middle_end) { + Cover cover; + collect_cover(middle_begin, middle_end, cover); + for (std::size_t i = 0; i < cover.size; ++i) { + const NodeRef& node = cover.nodes[i]; + Summary summary = summary_at(node.level, node.index); + consider(value + summary.min_excess, summary.min_count); + value += summary.block_excess; + } + } + + if (middle_end < range_end_exclusive) { + const ScanResult scan = scan_range(middle_end, range_end_exclusive); + consider(value + scan.min_value, scan.min_count); + } + return {best == std::numeric_limits::max() ? 0 : best, count}; + } + + /** + * @brief Select the q-th position attaining the range minimum. + * @details Assumes @p target is the range-minimum value. Skips whole scanned + * pieces whose minimum is not @p target or whose minimum count is before + * @p rank, then descends/scans the selected piece. + * @param range_begin Inclusive range start. + * @param range_end Inclusive range end. + * @param target Minimum relative excess value to select. + * @param rank One-based rank among positions attaining @p target. + * @return Selected bit position, or `npos`. + */ + std::size_t range_min_select(std::size_t range_begin, + std::size_t range_end, + std::int64_t target, + std::uint64_t rank) const { + std::int64_t value = 0; + const std::size_t range_end_exclusive = range_end + 1; + const std::size_t first_full_block = + (range_begin + kBlockBits - 1) / kBlockBits; + const std::size_t full_begin = + std::min(range_end_exclusive, first_full_block * kBlockBits); + if (range_begin < full_begin) { + const ScanResult scan = scan_range(range_begin, full_begin); + if (scan.min_value == target) { + if (rank <= scan.min_count) { + return qth_min_in_range(range_begin, full_begin, target, rank); + } + rank -= scan.min_count; + } + value += scan.block_excess; + } + + const std::size_t last_full_block_exclusive = + range_end_exclusive / kBlockBits; + const std::size_t middle_begin = full_begin; + const std::size_t middle_end = + std::max(middle_begin, last_full_block_exclusive * kBlockBits); + if (middle_begin < middle_end) { + Cover cover; + collect_cover(middle_begin, middle_end, cover); + for (std::size_t i = 0; i < cover.size; ++i) { + const NodeRef& node = cover.nodes[i]; + Summary summary = summary_at(node.level, node.index); + const std::int64_t candidate = value + summary.min_excess; + if (candidate == target) { + if (rank <= summary.min_count) { + return descend_qth_min(node.level, node.index, target - value, + rank); + } + rank -= summary.min_count; + } + value += summary.block_excess; + } + } + + if (middle_end < range_end_exclusive) { + const ScanResult scan = scan_range(middle_end, range_end_exclusive); + if (value + scan.min_value == target) { + return qth_min_in_range(middle_end, range_end_exclusive, target - value, + rank); + } + } + return npos; + } + + /** + * @brief Decompose an aligned block interval into summary nodes. + * @details Produces a left-to-right cover of `[begin, end)` using the largest + * summary nodes available. Both boundaries must be aligned to block size. + * @param begin Inclusive bit start, aligned to `kBlockBits`. + * @param end Exclusive bit end, aligned to `kBlockBits`. + * @param out Destination fixed-size cover buffer. + */ + void collect_cover(std::size_t begin, std::size_t end, Cover& out) const { + if (begin >= end || total_levels() == 0 || (begin % kBlockBits) != 0 || + (end % kBlockBits) != 0) { + return; + } + + Cover right_cover; + std::size_t level = 0; + std::size_t left = begin / kBlockBits; + std::size_t right = end / kBlockBits; + + while (left < right) { + if (!has_parent_level(level)) { + for (std::size_t index = left; index < right; ++index) { + out.push({level, index}); + } + break; + } + + const std::size_t fanout = fanout_to_parent(level); + while (left < right && (left % fanout) != 0) { + out.push({level, left}); + ++left; + } + while (left < right && (right % fanout) != 0) { + --right; + right_cover.push({level, right}); + } + left /= fanout; + right /= fanout; + ++level; + } + + while (right_cover.size > 0) { + out.push(right_cover.nodes[--right_cover.size]); + } + } + + /** + * @brief Descend to the first bit position attaining an extreme value. + * @details Starting from a node whose relative min/max equals @p target, + * chooses the first child preserving that target until a block is reached. + * @param level Starting summary level. + * @param index Node index at @p level. + * @param target Relative target excess inside the starting node. + * @param find_min True for minimum, false for maximum. + * @return First bit position attaining the target, or `npos`. + */ + std::size_t descend_first_extreme(std::size_t level, + std::size_t index, + std::int64_t target, + bool find_min) const { + while (level > 0) { + const std::size_t child_level = level - 1; + const std::size_t fanout = fanout_to_parent(child_level); + const std::size_t child_begin = index * fanout; + const std::size_t child_end = + std::min(level_count(child_level), child_begin + fanout); + std::int64_t prefix = 0; + bool found = false; + for (std::size_t child = child_begin; child < child_end; ++child) { + Summary summary = summary_at(child_level, child); + const std::int64_t candidate = + prefix + (find_min ? summary.min_excess : summary.max_excess); + if (candidate == target) { + index = child; + level = child_level; + target -= prefix; + found = true; + break; + } + prefix += summary.block_excess; + } + if (!found) { + return npos; + } + } + return first_prefix_in_block(index, target); + } + + /** + * @brief Find the first prefix in a 512-bit block with the target excess. + * @details Uses `excess_positions_512` for full blocks and scalar scanning + * for partial blocks. + * @param block_index Zero-based block index. + * @param target Target excess relative to the beginning of the block. + * @return First bit position whose inclusive block prefix reaches @p target, + * or `npos`. + */ + std::size_t first_prefix_in_block(std::size_t block_index, + std::int64_t target) const { + if (full_block_has_words(block_index) && target >= -512 && target <= 512) { + std::uint64_t out[kBlockWords]; + excess_positions_512(bits_.data() + block_index * kBlockWords, + static_cast(target), out); + for (std::size_t word = 0; word < kBlockWords; ++word) { + const std::uint64_t mask = out[word]; + if (mask != 0) { + return block_index * kBlockBits + word * 64 + std::countr_zero(mask); + } + } + return npos; + } + + const std::size_t begin = block_index * kBlockBits; + const std::size_t length = block_size(block_index); + std::int64_t current = 0; + for (std::size_t offset = 0; offset < length; ++offset) { + current += bit(begin + offset) ? 1 : -1; + if (current == target) { + return begin + offset; + } + } + return npos; + } + + /** + * @brief Descend to the q-th position attaining a minimum value. + * @details Starting from a node known to contain occurrences of @p target, + * skips child subtrees by minimum count until the q-th occurrence is located. + * @param level Starting summary level. + * @param index Node index at @p level. + * @param target Relative minimum value inside the starting node. + * @param rank One-based rank among occurrences of @p target. + * @return Selected bit position, or `npos`. + */ + std::size_t descend_qth_min(std::size_t level, + std::size_t index, + std::int64_t target, + std::uint64_t rank) const { + while (level > 0) { + const std::size_t child_level = level - 1; + const std::size_t fanout = fanout_to_parent(child_level); + const std::size_t child_begin = index * fanout; + const std::size_t child_end = + std::min(level_count(child_level), child_begin + fanout); + std::int64_t prefix = 0; + bool found = false; + for (std::size_t child = child_begin; child < child_end; ++child) { + Summary summary = summary_at(child_level, child); + const std::int64_t candidate = prefix + summary.min_excess; + if (candidate == target) { + if (rank <= summary.min_count) { + index = child; + level = child_level; + target -= prefix; + found = true; + break; + } + rank -= summary.min_count; + } + prefix += summary.block_excess; + } + if (!found) { + return npos; + } + } + return qth_min_in_range(index * kBlockBits, + index * kBlockBits + block_size(index), target, + rank); + } + + /** + * @brief Scan bits to select the q-th occurrence of a target minimum. + * @details Performs a scalar scan over `[begin, end)`, accumulating relative + * excess from zero. + * @param begin Inclusive bit start. + * @param end Exclusive bit end. + * @param target Target relative excess. + * @param rank One-based occurrence rank. + * @return Selected bit position, or `npos`. + */ + std::size_t qth_min_in_range(std::size_t begin, + std::size_t end, + std::int64_t target, + std::uint64_t rank) const { + std::int64_t current = 0; + for (std::size_t position = begin; position < end; ++position) { + current += bit(position) ? 1 : -1; + if (current == target && --rank == 0) { + return position; + } + } + return npos; + } + + /** + * @brief Byte-accelerated scan of an arbitrary bit interval. + * @details Handles unaligned edge bits individually and consumes aligned + * bytes via `byte_lut`, producing total excess plus min/max positions and min + * count. + * @param begin Inclusive bit start. + * @param end Exclusive bit end. + * @return Relative scan summary for `[begin, end)`. + */ + ScanResult scan_range(std::size_t begin, std::size_t end) const { + ScanResult result; + const auto& lut = byte_lut(); + while (begin < end && (begin & 7) != 0) { + append_scanned_bit(result, begin); + ++begin; + } + while (begin + 8 <= end) { + const ByteAgg& byte = lut[get_byte(begin)]; + const std::int64_t min_candidate = result.block_excess + byte.min_excess; + if (min_candidate < result.min_value) { + result.min_value = min_candidate; + result.min_count = byte.min_count; + result.min_position = begin + byte.pos_first_min; + } else if (min_candidate == result.min_value) { + result.min_count += byte.min_count; + } + const std::int64_t max_candidate = result.block_excess + byte.max_excess; + if (max_candidate > result.max_value) { + result.max_value = max_candidate; + result.max_position = begin + byte.pos_first_max; + } + result.block_excess += byte.block_excess; + begin += 8; + } + while (begin < end) { + append_scanned_bit(result, begin); + ++begin; + } + return result; + } + + /** + * @brief Append one bit to an incremental range scan. + * @details Updates total excess, min/max values, first min/max positions, and + * minimum count after consuming @p position. + * @param result Scan accumulator to update. + * @param position Bit position to consume. + */ + void append_scanned_bit(ScanResult& result, std::size_t position) const { + result.block_excess += bit(position) ? 1 : -1; + if (result.block_excess < result.min_value) { + result.min_value = result.block_excess; + result.min_count = 1; + result.min_position = position; + } else if (result.block_excess == result.min_value) { + ++result.min_count; + } + if (result.block_excess > result.max_value) { + result.max_value = result.block_excess; + result.max_position = position; + } + } + + /** + * @brief Return the byte starting at a byte-aligned bit position. + * @details Reads eight bits from the backing span by shifting the containing + * word. Callers only use this on byte-aligned positions that do not cross a + * word boundary. + * @param bit_position Byte-aligned bit position. + * @return Low eight bits starting at @p bit_position. + */ + std::uint8_t get_byte(std::size_t bit_position) const { + return static_cast( + (bits_[bit_position >> 6] >> (bit_position & 63)) & 0xffu); + } + + /** + * @brief Read one bit from the backing span. + * @details Interprets the backing words as little-endian bit order. + * @param position Bit position to read. + * @return 0 or 1. + */ + std::uint8_t bit(std::size_t position) const { + return static_cast((bits_[position >> 6] >> (position & 63)) & + 1ull); + } + + std::span bits_; + std::optional rank_index_; + std::size_t bit_count_ = 0; + std::size_t block_count_ = 0; + Summary top_summary_; + std::vector level_counts_; + std::vector> low_levels_; + std::vector> high_levels_; +}; + +} // namespace pixie::experimental diff --git a/include/pixie/rmm_tree.h b/include/pixie/rmm_tree.h index e26a5b1..e815c4e 100644 --- a/include/pixie/rmm_tree.h +++ b/include/pixie/rmm_tree.h @@ -928,46 +928,46 @@ class RmMTree : public RmMBase { /** * @brief close_impl(@p open_position): matching ')' for '(' at @p * open_position. - * @todo This method still uses the older boundary-adjusted search convention; - * align it with SDSL-style zero-based parenthesis indexing. * @return Position of matching ')', or npos. */ inline size_t close_impl(const size_t& open_position) const { if (open_position >= num_bits) { return npos; } - return fwdsearch_impl(open_position, -1); + if (!bit(open_position)) { + return open_position; + } + return fwdsearch_impl(open_position, 0); } /** * @brief open_impl(@p close_position): matching '(' for ')' at @p * close_position. - * @todo This method still uses the older boundary-adjusted search convention; - * align it with SDSL-style zero-based parenthesis indexing. * @return Position of matching '(', or npos. */ inline size_t open_impl(const size_t& close_position) const { - // bwdsearch allows i in [1..num_bits] - if (close_position == 0 || close_position > num_bits) { + if (close_position >= num_bits) { return npos; } - const size_t result = bwdsearch_impl(close_position, 0); - return (result == npos ? npos : result + 1); + if (bit(close_position)) { + return close_position; + } + return bwdsearch_impl(close_position + 1, 0); } /** * @brief enclose_impl(@p position): opening '(' that strictly encloses @p * position. - * @todo This method still uses the older boundary-adjusted search convention; - * align it with SDSL-style zero-based parenthesis indexing. * @return Position of enclosing '(', or npos. */ inline size_t enclose_impl(const size_t& position) const { - if (position == 0 || position > num_bits) { + if (position >= num_bits) { return npos; } - const size_t result = bwdsearch_impl(position, -2); - return (result == npos ? npos : result + 1); + if (!bit(position)) { + return open_impl(position); + } + return bwdsearch_impl(position + 1, -2); } private: diff --git a/include/pixie/rmm_tree_sdsl.h b/include/pixie/rmm_tree_sdsl.h index 3245857..c16103c 100644 --- a/include/pixie/rmm_tree_sdsl.h +++ b/include/pixie/rmm_tree_sdsl.h @@ -11,9 +11,18 @@ #include #include #include + +// SDSL keeps the generic excess-search primitives private and exposes only +// navigation wrappers such as find_close/find_open. The Pixie comparison +// backend needs direct fwdsearch/bwdsearch, so expose them in this optional +// adapter instead of benchmarking a naive fallback. +#define private public #include +#undef private + #include #include +#include namespace pixie { @@ -39,9 +48,16 @@ class SdslRmMTree : public RmMBase { zeros_ = size_ - ones_; bits_ = sdsl::bit_vector(size_); + prefix_excess_.assign(size_ + 1, 0); + int current_excess = 0; for (std::size_t i = 0; i < size_; ++i) { - bits_[i] = (words[i >> 6] >> (i & 63)) & 1u; + const bool bit = (words[i >> 6] >> (i & 63)) & 1u; + bits_[i] = bit; + current_excess += bit ? 1 : -1; + prefix_excess_[i + 1] = current_excess; + max_excess_ = std::max(max_excess_, current_excess); } + build_excess_bounds(); tree_ = BpSupport(&bits_); } @@ -49,6 +65,12 @@ class SdslRmMTree : public RmMBase { : size_(other.size_), ones_(other.ones_), zeros_(other.zeros_), + max_excess_(other.max_excess_), + prefix_excess_(other.prefix_excess_), + prefix_min_excess_(other.prefix_min_excess_), + prefix_max_excess_(other.prefix_max_excess_), + suffix_min_excess_(other.suffix_min_excess_), + suffix_max_excess_(other.suffix_max_excess_), bits_(other.bits_) { reset_support(); } @@ -60,6 +82,12 @@ class SdslRmMTree : public RmMBase { size_ = other.size_; ones_ = other.ones_; zeros_ = other.zeros_; + max_excess_ = other.max_excess_; + prefix_excess_ = other.prefix_excess_; + prefix_min_excess_ = other.prefix_min_excess_; + prefix_max_excess_ = other.prefix_max_excess_; + suffix_min_excess_ = other.suffix_min_excess_; + suffix_max_excess_ = other.suffix_max_excess_; bits_ = other.bits_; reset_support(); return *this; @@ -69,6 +97,12 @@ class SdslRmMTree : public RmMBase { : size_(other.size_), ones_(other.ones_), zeros_(other.zeros_), + max_excess_(other.max_excess_), + prefix_excess_(std::move(other.prefix_excess_)), + prefix_min_excess_(std::move(other.prefix_min_excess_)), + prefix_max_excess_(std::move(other.prefix_max_excess_)), + suffix_min_excess_(std::move(other.suffix_min_excess_)), + suffix_max_excess_(std::move(other.suffix_max_excess_)), bits_(std::move(other.bits_)) { reset_support(); } @@ -80,6 +114,12 @@ class SdslRmMTree : public RmMBase { size_ = other.size_; ones_ = other.ones_; zeros_ = other.zeros_; + max_excess_ = other.max_excess_; + prefix_excess_ = std::move(other.prefix_excess_); + prefix_min_excess_ = std::move(other.prefix_min_excess_); + prefix_max_excess_ = std::move(other.prefix_max_excess_); + suffix_min_excess_ = std::move(other.suffix_min_excess_); + suffix_max_excess_ = std::move(other.suffix_max_excess_); bits_ = std::move(other.bits_); reset_support(); return *this; @@ -126,18 +166,61 @@ class SdslRmMTree : public RmMBase { if (start_position >= size_) { return npos; } - std::int64_t current = excess_impl(start_position); - const std::int64_t target = current + delta; - for (std::size_t position = start_position; position < size_; ++position) { - current += bits_[position] ? 1 : -1; - if (current == target) { - return position; + const int target = prefix_excess_[start_position] + delta; + if (target > max_excess_) { + return npos; + } + + if (start_position == 0) { + const int first_excess = bits_[0] ? 1 : -1; + if (first_excess == delta) { + return 0; } + if (!suffix_contains(2, target)) { + return npos; + } + const std::size_t position = + tree_.fwd_excess(0, static_cast( + delta - first_excess)); + return position < size_ ? position : npos; } - return npos; + + if (!suffix_contains(start_position + 1, target)) { + return npos; + } + const std::size_t position = tree_.fwd_excess( + start_position - 1, + static_cast(delta)); + return position < size_ ? position : npos; } - std::size_t bwdsearch_impl(std::size_t, int) const { return npos; } + std::size_t bwdsearch_impl(std::size_t start_position, int delta) const { + if (start_position == 0 || start_position > size_) { + return npos; + } + + const std::size_t anchor = start_position - 1; + const int target = prefix_excess_[start_position] + delta; + if (target > max_excess_) { + return npos; + } + if (prefix_excess_[anchor] == target) { + return anchor; + } + if (anchor == 0) { + return npos; + } + if (!prefix_contains(anchor - 1, target)) { + return npos; + } + + const std::size_t position = tree_.bwd_excess( + anchor, static_cast(delta)); + if (position == static_cast(-1)) { + return target == 0 ? 0 : npos; + } + return position < size_ ? position + 1 : npos; + } std::size_t range_min_query_pos_impl(std::size_t range_begin, std::size_t range_end) const { @@ -174,29 +257,75 @@ class SdslRmMTree : public RmMBase { if (size_ == 0) { return 0; } - return tree_.find_close(open_position); + const std::size_t position = tree_.find_close(open_position); + return position < size_ ? position : npos; } std::size_t open_impl(std::size_t close_position) const { if (size_ == 0) { return 0; } - return tree_.find_open(close_position); + const std::size_t position = tree_.find_open(close_position); + return position < size_ ? position : npos; } std::size_t enclose_impl(std::size_t open_position) const { if (size_ == 0) { return 0; } - return tree_.enclose(open_position); + const std::size_t position = tree_.enclose(open_position); + return position < size_ ? position : npos; } private: void reset_support() { tree_ = BpSupport(&bits_); } + void build_excess_bounds() { + prefix_min_excess_.resize(size_ + 1); + prefix_max_excess_.resize(size_ + 1); + suffix_min_excess_.resize(size_ + 1); + suffix_max_excess_.resize(size_ + 1); + + prefix_min_excess_[0] = prefix_excess_[0]; + prefix_max_excess_[0] = prefix_excess_[0]; + for (std::size_t i = 1; i <= size_; ++i) { + prefix_min_excess_[i] = + std::min(prefix_min_excess_[i - 1], prefix_excess_[i]); + prefix_max_excess_[i] = + std::max(prefix_max_excess_[i - 1], prefix_excess_[i]); + } + + suffix_min_excess_[size_] = prefix_excess_[size_]; + suffix_max_excess_[size_] = prefix_excess_[size_]; + for (std::size_t i = size_; i > 0;) { + --i; + suffix_min_excess_[i] = + std::min(prefix_excess_[i], suffix_min_excess_[i + 1]); + suffix_max_excess_[i] = + std::max(prefix_excess_[i], suffix_max_excess_[i + 1]); + } + } + + bool suffix_contains(std::size_t boundary_begin, int target) const { + return boundary_begin <= size_ && + suffix_min_excess_[boundary_begin] <= target && + target <= suffix_max_excess_[boundary_begin]; + } + + bool prefix_contains(std::size_t boundary_end, int target) const { + return prefix_min_excess_[boundary_end] <= target && + target <= prefix_max_excess_[boundary_end]; + } + std::size_t size_{}; std::size_t ones_{}; std::size_t zeros_{}; + int max_excess_{}; + std::vector prefix_excess_; + std::vector prefix_min_excess_; + std::vector prefix_max_excess_; + std::vector suffix_min_excess_; + std::vector suffix_max_excess_; sdsl::bit_vector bits_; BpSupport tree_; }; diff --git a/include/reference_implementations/naive_rmm_tree.h b/include/reference_implementations/naive_rmm_tree.h index 79ecfe9..a9baafb 100644 --- a/include/reference_implementations/naive_rmm_tree.h +++ b/include/reference_implementations/naive_rmm_tree.h @@ -260,20 +260,27 @@ class NaiveRmM { if (i >= num_bits) { return npos; } - return fwdsearch(i, -1); + if (!bits[i]) { + return i; + } + return fwdsearch(i, 0); } std::size_t open(std::size_t i) const { - if (i == 0 || i > num_bits) { + if (i >= num_bits) { return npos; } - auto r = bwdsearch(i, 0); - return (r == npos ? npos : r + 1); + if (bits[i]) { + return i; + } + return bwdsearch(i + 1, 0); } std::size_t enclose(std::size_t i) const { - if (i == 0 || i > num_bits) { + if (i >= num_bits) { return npos; } - auto r = bwdsearch(i, -2); - return (r == npos ? npos : r + 1); + if (!bits[i]) { + return open(i); + } + return bwdsearch(i + 1, -2); } }; diff --git a/src/benchmarks/bench_rmm_btree.cpp b/src/benchmarks/bench_rmm_btree.cpp new file mode 100644 index 0000000..1c18854 --- /dev/null +++ b/src/benchmarks/bench_rmm_btree.cpp @@ -0,0 +1,20 @@ +#include + +#include "rmm_benchmark_base.h" + +namespace pixie_bench { + +template <> +struct RmMBenchmarkTraits> { + static constexpr std::size_t DefaultBlockBits = + pixie::experimental::RmMBTree<>::kBlockBits; + + static bool SupportsOp(std::string_view) { return true; } +}; + +} // namespace pixie_bench + +int main(int argc, char** argv) { + pixie_bench::RmMBenchmark> benchmark; + return benchmark.Run(argc, argv); +} diff --git a/src/benchmarks/bench_rmm_sdsl.cpp b/src/benchmarks/bench_rmm_sdsl.cpp index 0dd9eac..e722200 100644 --- a/src/benchmarks/bench_rmm_sdsl.cpp +++ b/src/benchmarks/bench_rmm_sdsl.cpp @@ -12,9 +12,9 @@ struct RmMBenchmarkTraits { static bool SupportsOp(std::string_view op) { return op == "rank1" || op == "rank0" || op == "select1" || - op == "excess" || op == "fwdsearch" || op == "range_min_query_pos" || - op == "range_min_query_val" || op == "close" || op == "open" || - op == "enclose"; + op == "excess" || op == "fwdsearch" || op == "bwdsearch" || + op == "range_min_query_pos" || op == "range_min_query_val" || + op == "close" || op == "open" || op == "enclose"; } }; diff --git a/src/benchmarks/excess_positions_benchmarks.cpp b/src/benchmarks/excess_positions_benchmarks.cpp index 62e0e52..4bbf28f 100644 --- a/src/benchmarks/excess_positions_benchmarks.cpp +++ b/src/benchmarks/excess_positions_benchmarks.cpp @@ -9,6 +9,7 @@ #include using pixie::experimental::excess_positions_512_branching_lut; +using pixie::experimental::excess_positions_512_byte_lut; using pixie::experimental::excess_positions_512_expand; using pixie::experimental::excess_positions_512_expand8; using pixie::experimental::excess_positions_512_expand_avx512; @@ -217,3 +218,29 @@ BENCHMARK(BM_ExcessPositions512_Scalar) ->Args({0}) ->Args({8}) ->Args({64}); + +static void BM_ExcessPositions512_ByteLUT(benchmark::State& state) { + const int target_x = state.range(0); + const auto blocks = make_blocks(); + const size_t num_blocks = blocks.size(); + + alignas(64) uint64_t out[8]; + size_t idx = 0; + + for (auto _ : state) { + const auto& s = blocks[idx % num_blocks]; + excess_positions_512_byte_lut(s.data(), target_x, out); + benchmark::DoNotOptimize(out); + ++idx; + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_ExcessPositions512_ByteLUT) + ->ArgNames({"X"}) + ->Args({-64}) + ->Args({-8}) + ->Args({0}) + ->Args({8}) + ->Args({64}); diff --git a/src/benchmarks/rmm_benchmark_base.h b/src/benchmarks/rmm_benchmark_base.h index 6f2b601..c4f4e28 100644 --- a/src/benchmarks/rmm_benchmark_base.h +++ b/src/benchmarks/rmm_benchmark_base.h @@ -39,9 +39,7 @@ struct RmMBenchmarkPools { std::vector inds_any; std::vector rank10_end_positions; std::vector open_positions_zero_based; - std::vector open_positions_one_based; std::vector close_positions_zero_based; - std::vector close_positions_one_based; std::vector inds; std::vector inds_1N; std::vector deltas; @@ -451,21 +449,15 @@ class RmMBenchmark { } std::vector open_positions_zero_based; - std::vector open_positions_one_based; std::vector close_positions_zero_based; - std::vector close_positions_one_based; if (need_open_positions || need_close_positions) { open_positions_zero_based.reserve(N >> 1); - open_positions_one_based.reserve(N >> 1); close_positions_zero_based.reserve(N >> 1); - close_positions_one_based.reserve(N >> 1); for (std::size_t i = 0; i < N; ++i) { if (data.bits[i] == '1') { open_positions_zero_based.push_back(i); - open_positions_one_based.push_back(i + 1); } else { close_positions_zero_based.push_back(i); - close_positions_one_based.push_back(i + 1); } } } @@ -550,20 +542,17 @@ class RmMBenchmark { } }; - const std::size_t one_based_fallback = (N > 0 ? 1 : 0); if (ActiveOp("close")) { fill_from_candidates(open_positions_zero_based, data.pool.open_positions_zero_based, 0); } if (ActiveOp("enclose")) { - fill_from_candidates(open_positions_one_based, - data.pool.open_positions_one_based, - one_based_fallback); + fill_from_candidates(open_positions_zero_based, + data.pool.open_positions_zero_based, 0); } if (ActiveOp("open")) { fill_from_candidates(close_positions_zero_based, - data.pool.close_positions_zero_based, - one_based_fallback); + data.pool.close_positions_zero_based, 0); } auto fill_ks = [&](std::size_t total, std::vector& out) { diff --git a/src/docs/benchmark_results.md b/src/docs/benchmark_results.md new file mode 100644 index 0000000..05b543f --- /dev/null +++ b/src/docs/benchmark_results.md @@ -0,0 +1,111 @@ +# Benchmark Results + +These results were generated on 2026-05-18 from `build/release` binaries on the local benchmark host. JSON inputs are kept under `src/docs/benchmarks`. + +## Excess Positions + +Command: + +```sh +./build/release/excess_positions_benchmarks \ + --benchmark_repetitions=5 \ + --benchmark_report_aggregates_only=true \ + --benchmark_display_aggregates_only=true \ + --benchmark_format=json \ + --benchmark_out=src/docs/benchmarks/excess_positions.json +python3 scripts/excess_benchmark_table.py \ + src/docs/benchmarks/excess_positions.json \ + -o src/docs/excess_positions_benchmark_results.md +``` + +| Method | X=-64 | X=-8 | X=0 | X=8 | X=64 | +|---|---:|---:|---:|---:|---:| +| BranchingLUT | 15.71 ns | 26.56 ns | 26.55 ns | 26.43 ns | 15.37 ns | +| Current | 10.73 ns | 18.09 ns | 18.58 ns | 18.24 ns | 10.43 ns | +| Expand | 60.70 ns | 88.41 ns | 87.38 ns | 88.50 ns | 56.77 ns | +| Expand8 | 19.13 ns | 53.05 ns | 47.83 ns | 49.35 ns | 17.44 ns | +| ExpandAVX512 | 23.13 ns | 38.39 ns | 38.56 ns | 39.24 ns | 23.19 ns | +| LUTAVX512 | 12.33 ns | 18.34 ns | 18.06 ns | 18.21 ns | 12.75 ns | +| Scalar | 304.42 ns | 389.58 ns | 446.94 ns | 399.80 ns | 316.23 ns | + +## BitVector Size Sweep + +The BitVector plot uses the 50/50 fill variants for rank/select over the registered benchmark size grid. The benchmark definitions use fixed repeats, so the command-line repetition value is not used for this binary. + +Command: + +```sh +./build/release/benchmarks \ + --benchmark_filter='BM_(RankInterleaved|RankNonInterleaved|RankZeroNonInterleaved|SelectNonInterleaved|SelectZeroNonInterleaved)/' \ + --benchmark_report_aggregates_only=true \ + --benchmark_display_aggregates_only=true \ + --benchmark_format=json \ + --benchmark_out=src/docs/benchmarks/bitvector_size.json +python3 scripts/plot_size_benchmarks.py \ + src/docs/benchmarks/bitvector_size.json \ + -o src/docs/images/benchmarks/bitvector_size.png \ + --size-key n \ + --title 'BitVector benchmark time vs size' +``` + +![BitVector benchmark time vs size](images/benchmarks/bitvector_size.png) + +## RmM Tree Size Sweep + +The RmM comparison uses operations available in both Pixie and sdsl-lite over the same power-of-two tree sizes. Pixie's benchmark harness only constructs query pools needed by the selected operations. + +Pixie command: + +```sh +./build/release/bench_rmm \ + --ops=rank1,rank0,select1,excess,range_min_query_pos,range_min_query_val,close,open,enclose \ + --explicit_sizes=16384,32768,65536,131072,262144,524288,1048576,2097152,4194304 \ + --Q=32768 \ + --benchmark_repetitions=5 \ + --benchmark_report_aggregates_only=true \ + --benchmark_display_aggregates_only=true \ + --benchmark_format=json \ + --benchmark_out=src/docs/benchmarks/rmm_tree_size.json +``` + +sdsl-lite command: + +```sh +./build/release/bench_rmm_sdsl \ + --ops=rank1,rank0,select1,excess,range_min_query_pos,range_min_query_val,close,open,enclose \ + --explicit_sizes=16384,32768,65536,131072,262144,524288,1048576,2097152,4194304 \ + --Q=32768 \ + --benchmark_repetitions=5 \ + --benchmark_report_aggregates_only=true \ + --benchmark_display_aggregates_only=true \ + --benchmark_format=json \ + --benchmark_out=src/docs/benchmarks/rmm_tree_sdsl_size.json +``` + +Plot command: + +```sh +python3 scripts/plot_rmm.py \ + src/docs/benchmarks/rmm_tree_size.json \ + --sdsl-json src/docs/benchmarks/rmm_tree_sdsl_size.json \ + --save-dir src/docs/images/benchmarks/rmm_comparison \ + --logx +``` + +![rank1 comparison](images/benchmarks/rmm_comparison/rank1.png) + +![rank0 comparison](images/benchmarks/rmm_comparison/rank0.png) + +![select1 comparison](images/benchmarks/rmm_comparison/select1.png) + +![excess comparison](images/benchmarks/rmm_comparison/excess.png) + +![range_min_query_pos comparison](images/benchmarks/rmm_comparison/range_min_query_pos.png) + +![range_min_query_val comparison](images/benchmarks/rmm_comparison/range_min_query_val.png) + +![close comparison](images/benchmarks/rmm_comparison/close.png) + +![open comparison](images/benchmarks/rmm_comparison/open.png) + +![enclose comparison](images/benchmarks/rmm_comparison/enclose.png) diff --git a/src/tests/excess_positions_tests.cpp b/src/tests/excess_positions_tests.cpp index cb09bc3..57d69c8 100644 --- a/src/tests/excess_positions_tests.cpp +++ b/src/tests/excess_positions_tests.cpp @@ -10,6 +10,7 @@ #include using pixie::experimental::excess_positions_512_branching_lut; +using pixie::experimental::excess_positions_512_byte_lut; using pixie::experimental::excess_positions_512_expand; using pixie::experimental::excess_positions_512_expand8; using pixie::experimental::excess_positions_512_expand_avx512; @@ -35,6 +36,78 @@ static void naive_excess_positions_512(const uint64_t* s, } } +static int naive_excess_positions_128(const uint64_t* s, + int target_x, + uint64_t* out) { + out[0] = out[1] = 0; + const int block_excess = + 2 * (std::popcount(s[0]) + std::popcount(s[1])) - 128; + if (target_x < -128 || target_x > 128) { + return block_excess; + } + int cur = 0; + for (size_t i = 0; i < 128; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + if (cur == target_x) { + out[i >> 6] |= (uint64_t{1} << (i & 63)); + } + } + return block_excess; +} + +static int naive_prefix_excess_128(const uint64_t* s, size_t end_offset) { + end_offset = std::min(end_offset, 128); + int cur = 0; + for (size_t i = 0; i < end_offset; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + } + return cur; +} + +static size_t naive_forward_search_128(const uint64_t* s, + int target_x, + size_t start_offset) { + if (start_offset >= 128) { + return 128; + } + int cur = 0; + for (size_t i = 0; i < 128; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + if (i >= start_offset && cur == target_x) { + return i; + } + } + return 128; +} + +static size_t naive_backward_search_128(const uint64_t* s, + int target_x, + size_t end_offset) { + if (end_offset == 0) { + return 128; + } + const size_t max_prefix_length = end_offset - 1; + for (size_t prefix_length = max_prefix_length; prefix_length > 0; + --prefix_length) { + int cur = 0; + for (size_t i = 0; i < prefix_length; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + } + if (cur == target_x) { + return prefix_length; + } + } + return target_x == 0 ? 0 : 128; +} + static size_t count_matches(const uint64_t* out) { size_t cnt = 0; for (int w = 0; w < 8; ++w) { @@ -61,6 +134,66 @@ static void check_matches_naive(Fn fn, << fn_name << " case=" << case_id << " x=" << target_x; } +TEST(ExcessPositions128, MatchesNaiveMasksAndDelta) { + const std::array, 4> cases = {{ + {0, 0}, + {UINT64_MAX, UINT64_MAX}, + {0xAAAAAAAAAAAAAAAAull, 0x5555555555555555ull}, + {0x0123456789ABCDEFull, 0xFEDCBA9876543210ull}, + }}; + + for (const auto& s : cases) { + for (int x = -130; x <= 130; ++x) { + uint64_t out[2]; + uint64_t ref[2]; + const int delta = excess_positions_128(s.data(), x, out); + const int ref_delta = naive_excess_positions_128(s.data(), x, ref); + EXPECT_EQ(delta, ref_delta) << "x=" << x; + EXPECT_EQ(out[0], ref[0]) << "x=" << x; + EXPECT_EQ(out[1], ref[1]) << "x=" << x; + } + } +} + +TEST(ExcessPositions128, PrefixExcessMatchesNaive) { + std::mt19937_64 rng(42); + const std::array offsets = {0, 1, 2, 31, 32, 63, 64, + 65, 95, 96, 127, 128, 129}; + + for (int t = 0; t < 1000; ++t) { + const std::array s = {rng(), rng()}; + for (size_t offset : offsets) { + EXPECT_EQ(prefix_excess_128(s.data(), offset), + naive_prefix_excess_128(s.data(), offset)) + << "case=" << t << " offset=" << offset; + } + } +} + +TEST(ExcessPositions128, ForwardAndBackwardSearchMatchNaive) { + std::mt19937_64 rng(42); + const std::array offsets = {0, 1, 63, 64, 65, 126, 127, 128}; + + for (int t = 0; t < 1000; ++t) { + const std::array s = {rng(), rng()}; + for (int x = -128; x <= 128; x += 7) { + for (size_t offset : offsets) { + int block_excess = 0; + EXPECT_EQ(forward_search_128(s.data(), x, offset, &block_excess), + naive_forward_search_128(s.data(), x, offset)) + << "case=" << t << " x=" << x << " offset=" << offset; + EXPECT_EQ(block_excess, + 2 * (std::popcount(s[0]) + std::popcount(s[1])) - 128) + << "case=" << t; + + EXPECT_EQ(backward_search_128(s.data(), x, offset), + naive_backward_search_128(s.data(), x, offset)) + << "case=" << t << " x=" << x << " offset=" << offset; + } + } + } +} + TEST(ExcessPositions512, AllZeros) { alignas(64) uint64_t s[8] = {0, 0, 0, 0, 0, 0, 0, 0}; alignas(64) uint64_t out[8]; @@ -87,6 +220,7 @@ TEST(ExcessPositions512, AllZeros) { check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, x); check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x); + check_matches_naive(excess_positions_512_byte_lut, "byte_lut", s, x); } } @@ -117,6 +251,7 @@ TEST(ExcessPositions512, AllOnes) { check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, x); check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x); + check_matches_naive(excess_positions_512_byte_lut, "byte_lut", s, x); } } @@ -141,6 +276,7 @@ TEST(ExcessPositions512, Alternating) { check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, x); check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x); + check_matches_naive(excess_positions_512_byte_lut, "byte_lut", s, x); } } @@ -181,6 +317,8 @@ TEST(ExcessPositions512, ExhaustiveSmall16) { s, x, static_cast(pattern)); check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x, static_cast(pattern)); + check_matches_naive(excess_positions_512_byte_lut, "byte_lut", s, x, + static_cast(pattern)); for (int w = 0; w < 8; ++w) { ASSERT_EQ(out[w], ref[w]) << "pattern=" << pattern << " x=" << x << " word=" << w; @@ -221,6 +359,7 @@ TEST(ExcessPositions512, Random) { check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, x, t); check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x, t); + check_matches_naive(excess_positions_512_byte_lut, "byte_lut", s, x, t); for (int w = 0; w < 8; ++w) { ASSERT_EQ(out[w], ref[w]) << "case=" << t << " x=" << x << " word=" << w; @@ -252,6 +391,7 @@ TEST(ExcessPositions512, TargetZero) { check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, 0, t); check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, 0, t); + check_matches_naive(excess_positions_512_byte_lut, "byte_lut", s, 0, t); for (int w = 0; w < 8; ++w) { ASSERT_EQ(out[w], ref[w]) << "case=" << t << " word=" << w; } diff --git a/src/tests/test_rmm.cpp b/src/tests/test_rmm.cpp index c8b4bfa..6955634 100644 --- a/src/tests/test_rmm.cpp +++ b/src/tests/test_rmm.cpp @@ -1,5 +1,9 @@ #include +#include #include +#ifdef SDSL_SUPPORT +#include +#endif #include #include @@ -100,63 +104,6 @@ static std::string random_dyck_bits(std::mt19937_64& rng, size_t m) { return s; } -static size_t sdsl_style_close_ref(const std::string& bits, size_t position) { - if (position >= bits.size()) { - return NaiveRmM::npos; - } - if (bits[position] == '0') { - return position; - } - int balance = 1; - for (size_t i = position + 1; i < bits.size(); ++i) { - balance += bits[i] == '1' ? 1 : -1; - if (balance == 0) { - return i; - } - } - return NaiveRmM::npos; -} - -static size_t sdsl_style_open_ref(const std::string& bits, size_t position) { - if (position >= bits.size()) { - return NaiveRmM::npos; - } - if (bits[position] == '1') { - return position; - } - int balance = 0; - for (size_t i = position + 1; i > 0;) { - --i; - balance += bits[i] == '0' ? 1 : -1; - if (balance == 0 && bits[i] == '1') { - return i; - } - } - return NaiveRmM::npos; -} - -static size_t sdsl_style_enclose_ref(const std::string& bits, size_t position) { - if (position >= bits.size()) { - return NaiveRmM::npos; - } - if (bits[position] == '0') { - return sdsl_style_open_ref(bits, position); - } - std::vector stack; - stack.reserve(position + 1); - for (size_t i = 0; i <= position; ++i) { - if (bits[i] == '1') { - if (i == position) { - return stack.empty() ? NaiveRmM::npos : stack.back(); - } - stack.push_back(i); - } else if (!stack.empty()) { - stack.pop_back(); - } - } - return NaiveRmM::npos; -} - static constexpr uint64_t kSeed = 42; static constexpr size_t kRandomCases = 20; static constexpr size_t kOpsPerCase = 600; @@ -727,20 +674,52 @@ TEST(RmMSdslEdgeCases, IgnoresWordsBeyondBitCount) { EXPECT_EQ(rm.select1(3), pixie::SdslRmMTree::npos); } -TEST(RmMSdslEdgeCases, ParenthesesNavigationMatchesSdslStyleRefs) { +TEST(RmMSdslEdgeCases, ParenthesesNavigationMatchesNaive) { const std::string bits = "11101001011000"; auto words = pack_words_lsb_first(bits); pixie::SdslRmMTree rm(std::span(words), bits.size(), /*unused=*/0); + NaiveRmM nv(bits); for (size_t position = 0; position < bits.size(); ++position) { SCOPED_TRACE(::testing::Message() << "position=" << position << " bits=" << bits); - EXPECT_EQ(rm.close(position), sdsl_style_close_ref(bits, position)); - EXPECT_EQ(rm.open(position), sdsl_style_open_ref(bits, position)); - EXPECT_EQ(rm.enclose(position), sdsl_style_enclose_ref(bits, position)); + EXPECT_EQ(rm.close(position), nv.close(position)); + EXPECT_EQ(rm.open(position), nv.open(position)); + EXPECT_EQ(rm.enclose(position), nv.enclose(position)); } } + +TEST(RmMSdslEdgeCases, ForwardBackwardSearchMatchesNaive) { + const std::string bits = "11101001011000"; + auto words = pack_words_lsb_first(bits); + pixie::SdslRmMTree rm(std::span(words), bits.size(), + /*unused=*/0); + NaiveRmM nv(bits); + + for (size_t position = 0; position < bits.size(); ++position) { + for (int delta : {-4, -2, -1, 0, 1, 2, 4}) { + SCOPED_TRACE(::testing::Message() + << "position=" << position << " delta=" << delta); + EXPECT_EQ(rm.fwdsearch(position, delta), nv.fwdsearch(position, delta)); + EXPECT_EQ(rm.bwdsearch(position + 1, delta), + nv.bwdsearch(position + 1, delta)); + } + } +} + +TEST(RmMSdslEdgeCases, ExcessSearchSupportsNegativeTargets) { + const std::string bits = "001011"; + auto words = pack_words_lsb_first(bits); + pixie::SdslRmMTree rm(std::span(words), bits.size(), + /*unused=*/0); + NaiveRmM nv(bits); + + EXPECT_EQ(rm.fwdsearch(0, -1), nv.fwdsearch(0, -1)); + EXPECT_EQ(rm.fwdsearch(0, -2), nv.fwdsearch(0, -2)); + EXPECT_EQ(rm.bwdsearch(2, 0), nv.bwdsearch(2, 0)); + EXPECT_EQ(rm.bwdsearch(3, 1), nv.bwdsearch(3, 1)); +} #endif /** @@ -845,6 +824,23 @@ TEST(RmMEdgeCases, BoundaryHeavyQueries) { } } +TEST(RmMEdgeCases, SdslStyleParenthesesIndexing) { + const std::string bits = "11101001011000"; + auto words = pack_words_lsb_first(bits); + pixie::RmMTree rm(std::span(words), bits.size()); + + EXPECT_EQ(rm.close(0), bits.size() - 1); + EXPECT_EQ(rm.close(1), 6u); + EXPECT_EQ(rm.close(2), 3u); + EXPECT_EQ(rm.close(3), 3u); + EXPECT_EQ(rm.open(2), 2u); + EXPECT_EQ(rm.open(3), 2u); + EXPECT_EQ(rm.open(6), 1u); + EXPECT_EQ(rm.enclose(1), 0u); + EXPECT_EQ(rm.enclose(3), 2u); + EXPECT_EQ(rm.enclose(0), pixie::RmMTree::npos); +} + TEST_F(RmMRandomTest, LongRandom) { std::uniform_int_distribution coin(0, 1); std::uniform_int_distribution len_u(1, (int)kLongMaxBits); @@ -866,6 +862,374 @@ TEST_F(RmMRandomTest, LongRandom) { } } +static void run_btree_case_and_compare(const std::string& bits, + std::mt19937_64& rng, + size_t ops_per_case) { + auto words = pack_words_lsb_first(bits); + pixie::experimental::RmMBTree<> rm(std::span(words), + bits.size()); + NaiveRmM nv(bits); + + const size_t N = bits.size(); + const size_t ones = nv.rank1(N); + const size_t zeros = N - ones; + const size_t pairs10 = (N >= 2 ? nv.rank10(N) : 0); + + std::uniform_int_distribution pos_i(0, N); + std::uniform_int_distribution pos_i_nz(0, N ? N - 1 : 0); + std::uniform_int_distribution d_dist(-(int)std::min(N, 200), + (int)std::min(N, 200)); + + for (size_t q = 0; q < ops_per_case; ++q) { + const int which = std::uniform_int_distribution(0, 17)(rng); + size_t i = 0; + size_t j = 0; + if (N > 0) { + i = pos_i_nz(rng); + j = pos_i_nz(rng); + if (i > j) { + std::swap(i, j); + } + } + + SCOPED_TRACE(::testing::Message() + << "experimental btree bits=" << bits << " op=" << which); + switch (which) { + case 0: { + const size_t x = pos_i(rng); + EXPECT_EQ(rm.rank1(x), nv.rank1(x)); + } break; + case 1: { + const size_t x = pos_i(rng); + EXPECT_EQ(rm.rank0(x), nv.rank0(x)); + } break; + case 2: { + const size_t k = + std::uniform_int_distribution(0, ones + 3)(rng); + EXPECT_EQ(rm.select1(k), nv.select1(k)); + } break; + case 3: { + const size_t k = + std::uniform_int_distribution(0, zeros + 3)(rng); + EXPECT_EQ(rm.select0(k), nv.select0(k)); + } break; + case 4: { + const size_t x = + (N >= 2 ? std::uniform_int_distribution(0, N)(rng) : 0); + EXPECT_EQ(rm.rank10(x), nv.rank10(x)); + } break; + case 5: { + const size_t k = + std::uniform_int_distribution(0, pairs10 + 3)(rng); + EXPECT_EQ(rm.select10(k), nv.select10(k)); + } break; + case 6: { + const size_t x = pos_i(rng); + EXPECT_EQ(rm.excess(x), nv.excess(x)); + } break; + case 7: { + if (N == 0) { + break; + } + const size_t x = pos_i_nz(rng); + const int d = d_dist(rng); + EXPECT_EQ(rm.fwdsearch(x, d), nv.fwdsearch(x, d)); + } break; + case 8: { + if (N == 0) { + break; + } + const size_t x = pos_i_nz(rng); + const int d = d_dist(rng); + EXPECT_EQ(rm.bwdsearch(x, d), nv.bwdsearch(x, d)); + } break; + case 9: + if (N != 0) { + EXPECT_EQ(rm.range_min_query_pos(i, j), nv.range_min_query_pos(i, j)); + } + break; + case 10: + if (N != 0) { + EXPECT_EQ(rm.range_min_query_val(i, j), nv.range_min_query_val(i, j)); + } + break; + case 11: + if (N != 0) { + EXPECT_EQ(rm.mincount(i, j), nv.mincount(i, j)); + } + break; + case 12: + if (N != 0) { + const size_t count = nv.mincount(i, j); + const size_t k = count == 0 ? 1 + : std::uniform_int_distribution( + 1, count + 1)(rng); + EXPECT_EQ(rm.minselect(i, j, k), nv.minselect(i, j, k)); + } + break; + case 13: + if (N != 0) { + EXPECT_EQ(rm.range_max_query_pos(i, j), nv.range_max_query_pos(i, j)); + } + break; + case 14: + if (N != 0) { + EXPECT_EQ(rm.range_max_query_val(i, j), nv.range_max_query_val(i, j)); + } + break; + case 15: + if (N != 0) { + const size_t x = pos_i_nz(rng); + EXPECT_EQ(rm.close(x), nv.close(x)); + } + break; + case 16: + if (N != 0) { + const size_t x = pos_i_nz(rng); + EXPECT_EQ(rm.open(x), nv.open(x)); + } + break; + case 17: + if (N != 0) { + const size_t x = pos_i_nz(rng); + EXPECT_EQ(rm.enclose(x), nv.enclose(x)); + } + break; + } + } +} + +TEST(RmMBTreeExperimental, DifferentialRandom) { + std::mt19937_64 rng(kSeed); + std::uniform_int_distribution len_u(1, 4096); + for (size_t iter = 0; iter < 10; ++iter) { + run_btree_case_and_compare(random_bits(rng, (size_t)len_u(rng)), rng, 200); + } +} + +TEST(RmMBTreeExperimental, DifferentialBoundarySizes) { + std::mt19937_64 rng(kSeed); + for (size_t n : + std::array{1, 63, 64, 511, 512, 513, 1023, 1024, 2049}) { + run_btree_case_and_compare(random_bits(rng, n), rng, 200); + } +} + +TEST(RmMBTreeExperimental, SdslStyleParenthesesIndexing) { + const std::string bits = "11101001011000"; + auto words = pack_words_lsb_first(bits); + pixie::experimental::RmMBTree<> rm(std::span(words), + bits.size()); + + EXPECT_EQ(rm.close(0), bits.size() - 1); + EXPECT_EQ(rm.close(1), 6u); + EXPECT_EQ(rm.close(2), 3u); + EXPECT_EQ(rm.open(3), 2u); + EXPECT_EQ(rm.open(6), 1u); + EXPECT_EQ(rm.enclose(1), 0u); + EXPECT_EQ(rm.enclose(3), 2u); + EXPECT_EQ(rm.enclose(0), pixie::experimental::RmMBTree<>::npos); +} + +TEST(RmMBTreeExperimental, EmptyInput) { + std::vector words; + pixie::experimental::RmMBTree<> rm(std::span(words), + /*bit_count=*/0); + + EXPECT_EQ(rm.size(), 0u); + EXPECT_EQ(rm.rank1(0), 0u); + EXPECT_EQ(rm.rank0(0), 0u); + EXPECT_EQ(rm.select1(1), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.select0(1), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.rank10(0), 0u); + EXPECT_EQ(rm.select10(1), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.excess(0), 0); + EXPECT_EQ(rm.fwdsearch(0, 0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.bwdsearch(0, 0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.range_min_query_pos(0, 0), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.range_min_query_val(0, 0), 0); + EXPECT_EQ(rm.range_max_query_pos(0, 0), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.range_max_query_val(0, 0), 0); + EXPECT_EQ(rm.mincount(0, 0), 0u); + EXPECT_EQ(rm.minselect(0, 0, 1), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.close(0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.open(0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.enclose(0), pixie::experimental::RmMBTree<>::npos); +} + +TEST(RmMBTreeExperimental, SpanConstructorRejectsShortInputStorage) { + std::vector words(1, 0); + EXPECT_THROW( + (pixie::experimental::RmMBTree<>(std::span(words), + /*bit_count=*/65)), + std::invalid_argument); +} + +TEST(RmMBTreeExperimental, RankSelectIgnoresDirtyTrailingStorage) { + std::vector words = { + 0b101ull, std::numeric_limits::max()}; + pixie::experimental::RmMBTree<> rm(std::span(words), + /*bit_count=*/3); + + EXPECT_EQ(rm.rank1(3), 2u); + EXPECT_EQ(rm.rank1(128), 2u); + EXPECT_EQ(rm.rank0(3), 1u); + EXPECT_EQ(rm.rank0(128), 1u); + EXPECT_EQ(rm.select1(1), 0u); + EXPECT_EQ(rm.select1(2), 2u); + EXPECT_EQ(rm.select1(3), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.select0(1), 1u); + EXPECT_EQ(rm.select0(2), pixie::experimental::RmMBTree<>::npos); +} + +TEST(RmMBTreeExperimental, ParenthesesOnUnmatchedBoundaryBits) { + const std::string bits = "1"; + auto words = pack_words_lsb_first(bits); + pixie::experimental::RmMBTree<> rm(std::span(words), + bits.size()); + + EXPECT_EQ(rm.close(0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.open(0), 0u); + EXPECT_EQ(rm.enclose(0), pixie::experimental::RmMBTree<>::npos); +} + +TEST(RmMBTreeExperimental, InvalidArgumentsGuards) { + const std::string bits = "101100"; + auto words = pack_words_lsb_first(bits); + pixie::experimental::RmMBTree<> rm(std::span(words), + bits.size()); + + EXPECT_EQ(rm.select1(0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.select1(4), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.select0(0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.select0(4), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.select10(0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.select10(3), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.fwdsearch(bits.size(), 0), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.bwdsearch(0, 0), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.bwdsearch(bits.size() + 1, 0), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.range_min_query_pos(3, 2), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.range_min_query_pos(0, bits.size()), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.range_min_query_val(3, 2), 0); + EXPECT_EQ(rm.range_min_query_val(0, bits.size()), 0); + EXPECT_EQ(rm.range_max_query_pos(3, 2), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.range_max_query_pos(0, bits.size()), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.range_max_query_val(3, 2), 0); + EXPECT_EQ(rm.range_max_query_val(0, bits.size()), 0); + EXPECT_EQ(rm.mincount(3, 2), 0u); + EXPECT_EQ(rm.mincount(0, bits.size()), 0u); + EXPECT_EQ(rm.minselect(0, bits.size() - 1, 0), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.minselect(3, 2, 1), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.minselect(0, bits.size(), 1), + pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.close(bits.size()), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.open(bits.size()), pixie::experimental::RmMBTree<>::npos); + EXPECT_EQ(rm.enclose(bits.size()), pixie::experimental::RmMBTree<>::npos); +} + +TEST(RmMBTreeExperimental, FwdBwdSearchAcrossHighLevels) { + std::mt19937_64 rng(kSeed); + const size_t n = 512 * 32 * 8 + 4096; + const std::string bits = random_bits(rng, n); + auto words = pack_words_lsb_first(bits); + pixie::experimental::RmMBTree<> rm(std::span(words), + bits.size()); + NaiveRmM nv(bits); + + const std::array positions = { + 0, + 511, + 512, + 512 * 32 - 1, + 512 * 32, + 512 * 32 + 1, + 512 * 32 * 8 - 1, + 512 * 32 * 8, + 512 * 32 * 8 + 1, + n - 1, + }; + for (size_t position : positions) { + for (int delta : {-128, -17, -2, -1, 0, 1, 2, 17, 128}) { + SCOPED_TRACE(::testing::Message() + << "position=" << position << " delta=" << delta); + EXPECT_EQ(rm.fwdsearch(position, delta), nv.fwdsearch(position, delta)); + EXPECT_EQ(rm.bwdsearch(position + 1, delta), + nv.bwdsearch(position + 1, delta)); + } + } +} + +TEST(RmMBTreeExperimental, BwdSearchReturnsNodeBoundaryWithoutDescend) { + using Tree = pixie::experimental::RmMBTree<>; + + const auto check_boundary = [](size_t start, size_t target) { + const std::string bits(start, '1'); + auto words = pack_words_lsb_first(bits); + Tree rm(std::span(words), bits.size()); + NaiveRmM nv(bits); + + const int delta = static_cast(target) - static_cast(start); + SCOPED_TRACE(::testing::Message() + << "start=" << start << " target=" << target); + EXPECT_EQ(rm.bwdsearch(start, delta), target); + EXPECT_EQ(rm.bwdsearch(start, delta), nv.bwdsearch(start, delta)); + }; + + check_boundary(3 * Tree::kBlockBits, Tree::kBlockBits); + + const size_t low_span = Tree::kBlockBits * Tree::kLowFanout; + check_boundary(3 * low_span, low_span); + + const size_t high_span = low_span * Tree::kHighFanout; + check_boundary(3 * high_span, high_span); +} + +TEST(RmMBTreeExperimental, MinCountAndMinSelectAcrossNodeBoundaries) { + std::string bits; + bits.reserve(512 * 40 + 37); + for (size_t i = 0; i < 512 * 40 + 37; ++i) { + bits.push_back((i & 1) ? '1' : '0'); + } + + auto words = pack_words_lsb_first(bits); + pixie::experimental::RmMBTree<> rm(std::span(words), + bits.size()); + NaiveRmM nv(bits); + + const std::array, 4> ranges = { + std::pair{0, bits.size() - 1}, + std::pair{3, 512 * 33 + 5}, + std::pair{511, 512 * 35}, + std::pair{512 * 31 - 7, 512 * 40 + 11}, + }; + for (const auto& [left, right] : ranges) { + SCOPED_TRACE(::testing::Message() + << "range=[" << left << "," << right << "]"); + EXPECT_EQ(rm.range_min_query_pos(left, right), + nv.range_min_query_pos(left, right)); + EXPECT_EQ(rm.range_min_query_val(left, right), + nv.range_min_query_val(left, right)); + EXPECT_EQ(rm.mincount(left, right), nv.mincount(left, right)); + + const size_t count = nv.mincount(left, right); + for (size_t rank = 1; rank <= count; ++rank) { + EXPECT_EQ(rm.minselect(left, right, rank), + nv.minselect(left, right, rank)) + << "rank=" << rank; + } + EXPECT_EQ(rm.minselect(left, right, count + 1), nv.npos); + } +} + TEST(RmMTest, RankBasic) { std::vector bits = {0b10110}; pixie::RmMTree rm(std::span(bits), 5); diff --git a/src/tests/unittests.cpp b/src/tests/unittests.cpp index 77a1b9e..1c435d3 100644 --- a/src/tests/unittests.cpp +++ b/src/tests/unittests.cpp @@ -324,6 +324,81 @@ TEST(BitVectorTest, SelectZeroBasic) { EXPECT_EQ(bv.select0(8), 13); } +TEST(BitVectorTest, EmptyAndZeroRankGuards) { + BitVector default_bv; + EXPECT_EQ(default_bv.size(), 0u); + EXPECT_EQ(default_bv.rank(0), 0u); + EXPECT_EQ(default_bv.rank0(0), 0u); + EXPECT_EQ(default_bv.select(0), 0u); + EXPECT_EQ(default_bv.select(1), 0u); + EXPECT_EQ(default_bv.select0(0), 0u); + EXPECT_EQ(default_bv.select0(1), 0u); + EXPECT_EQ(default_bv.to_string(), ""); + + std::vector empty; + BitVector bv(std::span(empty), 0); + + EXPECT_EQ(bv.size(), 0u); + EXPECT_EQ(bv.rank(0), 0u); + EXPECT_EQ(bv.rank0(0), 0u); + EXPECT_EQ(bv.select(0), 0u); + EXPECT_EQ(bv.select(1), 0u); + EXPECT_EQ(bv.select0(0), 0u); + EXPECT_EQ(bv.select0(1), 0u); + EXPECT_EQ(bv.to_string(), ""); + + std::vector bits(1, 0b10110); + BitVector non_empty(std::span(bits), 5); + EXPECT_EQ(non_empty.select(0), 0u); + EXPECT_EQ(non_empty.select0(0), 0u); +} + +TEST(BitVectorTest, ExactShortSpanRankSelect) { + std::vector bits = {0b1100010110010110}; + BitVector bv(bits, 16); + + EXPECT_EQ(bv.rank(16), 8); + EXPECT_EQ(bv.rank0(16), 8); + EXPECT_EQ(bv.select(1), 1); + EXPECT_EQ(bv.select(8), 15); + EXPECT_EQ(bv.select(9), bv.size()); + EXPECT_EQ(bv.select0(1), 0); + EXPECT_EQ(bv.select0(8), 13); + EXPECT_EQ(bv.select0(9), bv.size()); +} + +TEST(BitVectorTest, ShortSpanUsesScalarRankSelectFallbacks) { + std::vector one_in_second_word = {0, 1}; + BitVector ones(std::span(one_in_second_word), 65); + EXPECT_EQ(ones.rank(64), 0u); + EXPECT_EQ(ones.rank(65), 1u); + EXPECT_EQ(ones.rank0(65), 64u); + EXPECT_EQ(ones.select(1), 64u); + + std::vector zero_in_second_word = { + std::numeric_limits::max(), 0}; + BitVector zeros(std::span(zero_in_second_word), 65); + EXPECT_EQ(zeros.rank(65), 64u); + EXPECT_EQ(zeros.rank0(64), 0u); + EXPECT_EQ(zeros.rank0(65), 1u); + EXPECT_EQ(zeros.select0(1), 64u); +} + +TEST(BitVectorTest, IgnoresDirtyPaddingAndTrailingWords) { + std::vector bits = {~uint64_t{0}, ~uint64_t{0}}; + BitVector bv(bits, 3); + + EXPECT_EQ(bv.size(), 3); + EXPECT_EQ(bv.rank(3), 3); + EXPECT_EQ(bv.rank(128), 3); + EXPECT_EQ(bv.rank0(3), 0); + EXPECT_EQ(bv.rank0(128), 0); + EXPECT_EQ(bv.select(1), 0); + EXPECT_EQ(bv.select(3), 2); + EXPECT_EQ(bv.select(4), bv.size()); + EXPECT_EQ(bv.select0(1), bv.size()); +} + TEST(BitVectorTest, MainRankZeroTest) { std::mt19937_64 rng(42); std::vector bits(65536 * 32);