
Commit 2aed291

LouisYRYJ and smarter authored
Final EKFAC implementation (#123)
* ekfac implementation done (untested)
* remove unnecessary squeeze
* add tkfac
* fix claude issues
* shampoo
* minor fix
* Add EKFAC tests and fix a couple of bugs (#125)
* Fix mask bug and add batch size invariance test with toy model

  The backward_hook was using g.reshape(-1, O), which includes padding positions
  in the covariance computation. This causes incorrect results when batches have
  different sequence lengths. Before this commit, the added test failed with:

  > FAILED tests/ekfac_tests/test_batch_size_invariance.py::test_trace_batch_invariant[seq_lengths1-20] - AssertionError: Scalars are not close!
  >
  > Expected 1.231401894309304 but got 0.8983965093439276.
  > Absolute difference: 0.33300538496537635 (up to 1e-4 allowed)
  > Relative difference: 0.27042786478102654 (up to 0.01 allowed)

* Fix use_dataset_labels condition and add FIM accuracy test

  The condition `if not hessian_cfg.use_dataset_labels:` was inverted, causing
  the empirical Fisher (with dataset labels) to use sampled labels and vice
  versa.

  Add test_fim_accuracy.py, which verifies that KFAC approximates the Fisher
  Information Matrix within tolerance for both the empirical FIM (dataset
  labels) and the true FIM (sampled labels).

* Add ground truth EKFAC tests

  This is still missing FSDP support and test_apply_ekfac.py from #68

  Co-Authored-By: LouisYRYJ <[email protected]>

* ekfac_tests/test_batch_size_invariance.py: Fix error thresholds when running on CPU

* Cleanup EKFAC tests

  - Replace set_all_seeds with the existing setup_reproducibility
  - Reuse approximate_hessians instead of doing something equivalent manually

* Add --token_batch_size option to EKFAC tests

* Add --n_samples option to EKFAC tests

  Allow configuring the number of samples from the pile-10k dataset via a
  pytest command line option instead of hardcoding 100. The dataset directory
  is now named dynamically (e.g., pile_100_examples).
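The mask bug fixed in #125 can be reproduced in a few lines. This is a toy sketch with arbitrary shapes and plain tensors (not the repository's actual hook machinery), showing why `g.reshape(-1, O)` lets padding positions pollute the covariance while a boolean mask keeps the result invariant to padding:

```python
import torch

torch.manual_seed(0)
O = 3
# Two sequences padded to length 5; the second has only 3 valid positions.
g = torch.randn(2, 5, O)
mask = torch.tensor([[True, True, True, True, True],
                     [True, True, True, False, False]])

# Buggy accumulation: flattening the sequence dimension includes padding rows.
flat = g.reshape(-1, O)
cov_buggy = flat.T @ flat

# Fixed accumulation: boolean-mask out padding before the outer product.
g_valid = g[mask]                  # [n_valid, O]
cov_fixed = g_valid.T @ g_valid

# Ground truth: per-sequence accumulation over the true lengths.
cov_ref = g[0].T @ g[0] + g[1, :3].T @ g[1, :3]

assert torch.allclose(cov_fixed, cov_ref)
assert not torch.allclose(cov_buggy, cov_ref)
```

With real (nonzero) activations at padded positions, the unmasked covariance depends on how much padding each batch happens to contain, which is exactly the batch-size variance the test caught.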
* hessians: Fix distributed support and test it

  Restore the calls to dist.barrier that existed in #13; without them the
  process would hang when running with world_size > 1.

  For testing, we add _allocate_batches_world to compute the batches for the
  ground truth. The tests don't pass due to numerical errors; this is handled
  in the next commit by changing our comparison logic.

* ekfac_tests: Use appropriate metrics for each comparison

  - Eigenvectors: Check |cosine_similarity| ≈ 1 per column, which naturally
    handles sign ambiguity (eigenvectors are only defined up to sign)
  - Covariances: Check relative Frobenius norm, since values should match exactly
  - Eigenvalue corrections: Align signs based on eigenvector orientation, then
    check relative error (λ[i,j] transforms as sign_G[i] * sign_A[j])
  - Also re-enable CPU tests, which pass after this change

* ekfac_tests: Relax thresholds for distributed runs

  With world_size > 1, the floating-point reduction order differs between the
  ground truth (single process) and the distributed run, causing larger
  numerical differences in some layers.

  For eigenvectors, use the average |cos_sim| instead of the minimum; this
  tolerates occasional outlier eigenvectors while maintaining a stricter
  threshold (1e-3 vs the 0.1 that would be needed for the min).

  For eigenvalue corrections, use atol=0.2 when world_size > 1.

* adjust test + normalize shampoo and tkfac

* minor fixes, correct tensor handling in shampoo and tkfac, introduce apply_hessian (WIP)

---------

Co-authored-by: Guillaume Martres <[email protected]>
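The sign-invariant eigenvector comparison described in the commit message can be sketched as follows. This is a minimal illustration of the |cos_sim| metric on hypothetical orthonormal matrices, not the test suite's actual helper:

```python
import torch

def eigvec_similarity(V_ref: torch.Tensor, V_test: torch.Tensor) -> torch.Tensor:
    """Per-column |cosine similarity| between two eigenvector matrices.

    Eigenvectors are only defined up to sign, so |cos_sim| is used:
    it equals 1.0 whenever V_test[:, j] == ±V_ref[:, j].
    """
    cos = torch.nn.functional.cosine_similarity(V_ref, V_test, dim=0)
    return cos.abs()

# Flipping the sign of some columns does not change the metric.
Q, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
signs = torch.tensor([1.0, -1.0, 1.0, -1.0], dtype=torch.float64)
sims = eigvec_similarity(Q, Q * signs)
assert torch.allclose(sims, torch.ones(4, dtype=torch.float64))
```

Averaging `sims` instead of taking its minimum is what lets the distributed runs tolerate a few outlier columns while keeping a tight overall threshold.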
1 parent adbc3d1 commit 2aed291

28 files changed: +3612 −63 lines changed

bergson/__main__.py

Lines changed: 16 additions & 2 deletions
@@ -6,7 +6,8 @@
 from simple_parsing import ArgumentParser, ConflictResolution
 
 from .build import build
-from .config import IndexConfig, QueryConfig, ReduceConfig, ScoreConfig
+from .config import HessianConfig, IndexConfig, QueryConfig, ReduceConfig, ScoreConfig
+from .hessians.hessian_approximations import approximate_hessians
 from .query.query_index import query
 from .reduce import reduce
 from .score.score import score_dataset
@@ -98,11 +99,24 @@ def execute(self):
         query(self.query_cfg)
 
 
+@dataclass
+class Hessian:
+    """Approximate Hessian matrices using KFAC or EKFAC."""
+
+    hessian_cfg: HessianConfig
+    index_cfg: IndexConfig
+
+    def execute(self):
+        """Compute Hessian approximation."""
+        validate_run_path(self.index_cfg)
+        approximate_hessians(self.index_cfg, self.hessian_cfg)
+
+
 @dataclass
 class Main:
     """Routes to the subcommands."""
 
-    command: Union[Build, Query, Reduce, Score]
+    command: Union[Build, Query, Reduce, Score, Hessian]
 
     def execute(self):
         """Run the script."""
bergson/collector/collector.py

Lines changed: 122 additions & 14 deletions
@@ -1,11 +1,13 @@
 import functools
+import hashlib
 import os
 from abc import ABC, abstractmethod
 from contextlib import ContextDecorator, nullcontext
 from dataclasses import astuple, dataclass, field
 from fnmatch import fnmatchcase
 from typing import Callable, Literal, Mapping, Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -24,15 +26,14 @@
 from tqdm.auto import tqdm
 from transformers import PreTrainedModel
 
-from bergson.config import AttentionConfig, IndexConfig
+from bergson.config import AttentionConfig, HessianConfig, IndexConfig
 from bergson.data import pad_and_tensor
 from bergson.gradients import (
     GradientProcessor,
     LayerAdapter,
 )
 from bergson.utils.logger import get_logger
 from bergson.utils.peft import set_peft_enabled
-from bergson.utils.utils import create_projection_matrix
 
 
 @dataclass
@@ -78,6 +79,7 @@ class HookCollectorBase(ContextDecorator, ABC):
     Optional configuration specifying how to split up the attention module gradients
     into per-head gradients. See also bergson.config.AttentionConfig.
     """
+    logger = get_logger("HookCollectorBase", level="INFO")
 
     def __post_init__(
         self,
@@ -256,6 +258,28 @@ def projection(
             self.processor._projection_matrices[key] = A
         return A
 
+    def with_batch(self, valid_mask: Tensor | None = None) -> "HookCollectorBase":
+        """
+        Set the current batch indices and valid mask before entering the context.
+
+        This allows hooks to access batch indices and valid mask during
+        forward/backward passes.
+        Usage:
+            with collector.with_batch(indices, valid_mask):
+                # forward/backward pass
+                # hooks can access self._current_indices and self._current_valid_mask
+
+        Args:
+            indices: List of data indices in the current batch.
+            valid_mask: Optional boolean tensor of shape [batch_size, seq_len]
+                indicating which positions have valid labels for loss computation.
+
+        Returns:
+            self, for use as a context manager.
+        """
+        self._current_valid_mask = valid_mask
+        return self
+
     def __enter__(self):
         """Register forward and backward hooks on all target modules."""
         for name in self.target_info:
@@ -484,15 +508,23 @@ def run_with_collector_hooks(
         ):
             batch = self.data[indices]
 
+            # Compute padded tensors and valid_mask before entering context
+            x, y, valid_mask = pad_and_tensor(
+                batch["input_ids"],
+                labels=batch.get("labels"),
+                device=self.model.device,
+            )
+            total_processed += valid_mask.sum()
+
             with (
-                self.collector,
+                self.collector.with_batch(valid_mask),
                 (
                     record_function(f"step_{step}")
                     if self.cfg.profile
                     else nullcontext()
                 ),
             ):
-                losses = self.forward_backward(self.model, batch)
+                losses = self.forward_backward(self.model, x, y, batch)
 
             # TODO: currently builder also calls torch.cuda.synchronize
             torch.cuda.synchronize() if torch.cuda.is_available() else None
@@ -503,11 +535,17 @@
             step += 1
 
             self.collector.process_batch(indices, losses=losses)
-            total_processed += len(indices)
 
         self.collector.teardown()
+
         if dist.is_initialized():
             dist.all_reduce(total_processed, op=dist.ReduceOp.SUM)
+
+        if self.rank == 0:
+            torch.save(
+                total_processed,
+                os.path.join(self.cfg.partial_run_path, "total_processed.pt"),
+            )
         self.logger.info(f"Total processed: {total_processed.item()}")
 
 
@@ -523,18 +561,17 @@ def fwd_bwd_factory(cfg: IndexConfig) -> Callable:
     summed loss.
 
     Returns:
-        A callable fwd_bwd(model, batch) -> Tensor that performs a forward pass and
-        backward pass, returning the per-sample losses.
-        The batch must contain "input_ids" and optionally "labels" and "advantage".
+        A callable fwd_bwd(model, x, y, batch) -> Tensor that performs a forward pass
+        and backward pass, returning the per-sample losses.
+        Args:
+            model: The model to run forward/backward on.
+            x: Padded input token ids tensor of shape [batch_size, seq_len].
+            y: Padded label tensor of shape [batch_size, seq_len] with -100 for padding.
+            batch: Original batch dict, used only for "advantage" if present.
     Returns a tensor of shape [batch_size] with one loss value per sample.
     """
 
-    def fwd_bwd(model, batch):
-        x, y = pad_and_tensor(
-            batch["input_ids"],  # type: ignore
-            labels=batch.get("labels"),  # type: ignore
-            device=model.device,
-        )
+    def fwd_bwd(model, x: Tensor, y: Tensor, batch: dict):
         logits = model(x).logits[:, :-1]
         masks = y[:, 1:] != -100
         denoms = (
@@ -571,3 +608,74 @@ def fwd_bwd(model, batch):
         return losses
 
     return fwd_bwd
+
+
+def fwd_bwd_hessian_factory(
+    index_cfg: IndexConfig, hessian_cfg: HessianConfig
+) -> Callable:
+    def fwd_bwd_hessian(model, x: Tensor, y: Tensor, batch: dict):
+        logits = model(x).logits[:, :-1]
+        masks = y[:, 1:] != -100
+        denoms = (
+            masks.sum(dim=1, dtype=model.dtype)
+            if index_cfg.loss_reduction == "mean"
+            else 1.0
+        )
+        if hessian_cfg.use_dataset_labels:
+            losses = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)),
+                y[:, 1:].flatten(),
+                reduction="none",
+            ).reshape_as(y[:, 1:])
+            losses = losses.sum(1) / denoms
+        else:
+            with torch.no_grad():
+                probs = F.softmax(logits, dim=-1)
+                sampled_tokens = torch.multinomial(
+                    probs.reshape(-1, probs.size(-1)),
+                    num_samples=1,
+                    replacement=True,
+                ).reshape_as(y[:, 1:])
+            losses = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)),
+                sampled_tokens.flatten(),
+                reduction="none",
+            ).reshape_as(y[:, 1:])
+            losses = losses.sum(1) / denoms
+
+        losses.sum().backward()
+        model.zero_grad()
+
+        return losses
+
+    return fwd_bwd_hessian
+
+
+def create_projection_matrix(
+    identifier: str,
+    m: int,
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    projection_type: Literal["normal", "rademacher"] = "normal",
+) -> Tensor:
+    """Create a projection matrix deterministically based on identifier and side."""
+    # Seed the PRNG with the name of the layer and what "side" we are projecting
+    message = bytes(identifier, "utf-8")
+    digest = hashlib.md5(message).digest()
+    seed = int.from_bytes(digest, byteorder="big") % (2**63 - 1)
+
+    if projection_type == "normal":
+        prng = torch.Generator(device).manual_seed(seed)
+        A = torch.randn(m, n, device=device, dtype=dtype, generator=prng)
+    elif projection_type == "rademacher":
+        numpy_rng = np.random.Generator(np.random.PCG64(seed))
+        random_bytes = numpy_rng.bytes((m * n + 7) // 8)
+        random_bytes = np.frombuffer(random_bytes, dtype=np.uint8)
+        A = np.unpackbits(random_bytes)[: m * n].reshape((m, n))
+        A = torch.from_numpy(A).to(device, dtype=dtype)
+        A = A.add_(-0.5).mul_(2)
+    else:
+        raise ValueError(f"Unknown projection type: {projection_type}")
+    A /= A.norm(dim=1, keepdim=True)
+    return A
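The seeding scheme in create_projection_matrix above makes the matrix a pure function of its string identifier, so every rank can rebuild it independently without communication. A condensed sketch of just the seeding step (the layer identifier below is hypothetical, chosen for illustration):

```python
import hashlib

import torch

def seed_from_identifier(identifier: str) -> int:
    """Derive a deterministic 63-bit seed from a layer identifier, as in the diff above."""
    digest = hashlib.md5(identifier.encode("utf-8")).digest()
    return int.from_bytes(digest, byteorder="big") % (2**63 - 1)

# Same identifier -> same seed -> bit-identical projection matrix on every rank.
seed = seed_from_identifier("model.layers.0.mlp.down_proj/left")
A1 = torch.randn(4, 8, generator=torch.Generator().manual_seed(seed))
A2 = torch.randn(4, 8, generator=torch.Generator().manual_seed(seed))
assert torch.equal(A1, A2)
```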

bergson/config.py

Lines changed: 18 additions & 0 deletions
@@ -302,6 +302,24 @@ class ReduceConfig:
     """Whether to unit normalize the gradients before reducing them."""
 
 
+@dataclass
+class HessianConfig:
+    """Config for approximating the Hessian."""
+
+    method: Literal["kfac", "tkfac", "shampoo"] = "kfac"
+    """Method for approximating the Hessian."""
+
+    ev_correction: bool = False
+    """Whether to additionally compute eigenvalue correction."""
+
+    hessian_dtype: Literal["auto", "bf16", "fp16", "fp32"] = "auto"
+    """Precision (dtype) to use for the Hessian approximation."""
+
+    use_dataset_labels: bool = False
+    """Whether to use dataset labels for the Hessian (empirical Fisher) approximation.
+    If false, the model predictions will be used."""
+
+
 @dataclass
 class FaissConfig:
     """Configuration for FAISS index."""
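The use_dataset_labels flag added above selects between the empirical Fisher (dataset labels) and the true Fisher (labels sampled from the model's own distribution), mirroring the branching in fwd_bwd_hessian. A condensed standalone sketch with toy shapes and no model:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 4, 10)  # [batch, seq, vocab]

use_dataset_labels = False
if use_dataset_labels:
    # Empirical Fisher: targets come from the dataset (toy random labels here).
    targets = torch.randint(0, 10, (2, 4))
else:
    # True Fisher: sample targets from the model's predictive distribution.
    probs = F.softmax(logits, dim=-1)
    targets = torch.multinomial(
        probs.reshape(-1, probs.size(-1)), num_samples=1
    ).reshape(2, 4)

losses = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)), targets.flatten(), reduction="none"
).reshape(2, 4)
assert losses.shape == (2, 4)
```

The inverted condition fixed in #125 swapped these two branches, which is why the FIM accuracy test exercises both settings.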

bergson/data.py

Lines changed: 29 additions & 6 deletions
@@ -76,6 +76,23 @@ def allocate_batches(
     """
     rank = dist.get_rank() if dist.is_initialized() else 0
     world_size = dist.get_world_size() if dist.is_initialized() else 1
+    (batches,) = _allocate_batches_world(doc_lengths, N, world_size, seed, ranks=[rank])
+    return batches
+
+
+def _allocate_batches_world(
+    doc_lengths: list[int],
+    N: int,
+    world_size: int,
+    seed: int = 42,
+    ranks: list[int] | None = None,
+) -> list[list[list[int]]]:
+    """Lower-level version of allocate_batches that returns batches for specified ranks.
+
+    If ranks is None, returns batches for all ranks.
+    """
+    if ranks is None:
+        ranks = list(range(world_size))
     if len(doc_lengths) < world_size:
         raise RuntimeError("Not enough documents to distribute across workers.")
 
@@ -162,11 +179,12 @@
     # Sanity: equal # of batches per worker
     assert len({len(b) for b in allocation}) == 1
 
-    # Break any systematic ordering of batches
-    random.seed(seed)
-    random.shuffle(allocation[rank])
+    # Break any systematic ordering of batches (shuffle only requested ranks)
+    for rank in ranks:
+        random.seed(seed)
+        random.shuffle(allocation[rank])
 
-    return allocation[rank]
+    return [allocation[rank] for rank in ranks]
 
 
 def create_index(
@@ -466,7 +484,7 @@ def pad_and_tensor(
     padding_value: int = 0,
     dtype: torch.dtype | None = torch.long,
     device: torch.device | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Pad a list of sequences to the same length and convert them to tensors.
     Returns a tuple of padded sequences and labels. The labels are the same as the
@@ -485,7 +503,12 @@
     # convert to tensor
     padded_tokens = torch.tensor(padded, dtype=dtype, device=device)
     padded_labels = torch.tensor(labels, dtype=dtype, device=device)
-    return padded_tokens, padded_labels
+    # Compute valid_masks: position i is valid if labels[i+1] != -100
+    N, S = padded_tokens.shape
+    valid_masks = torch.zeros(N, S, dtype=torch.bool, device=device)
+    valid_masks[:, :-1] = padded_labels[:, 1:] != -100
+
+    return padded_tokens, padded_labels, valid_masks
 
 
 def tokenize(batch: dict, *, args: DataConfig, tokenizer):
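The valid-mask construction that pad_and_tensor now returns can be checked by hand. A minimal sketch with a toy label tensor (the values are illustrative, only the -100 padding convention comes from the diff):

```python
import torch

# Labels use -100 for padding/ignored positions.
labels = torch.tensor([[5, 7, 9, -100, -100],
                       [3, 2, -100, -100, -100]])
N, S = labels.shape

# Position i is valid iff it predicts a real (non -100) token at i + 1,
# matching the next-token shift in the loss computation.
valid = torch.zeros(N, S, dtype=torch.bool)
valid[:, :-1] = labels[:, 1:] != -100

assert valid.tolist() == [[True, True, False, False, False],
                          [True, False, False, False, False]]
```

Summing this mask (rather than counting sequences) is what makes the total_processed counter in collector.py count valid tokens instead of documents.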
