Skip to content

Commit b48ce8d

Browse files
authored
Merge pull request #80 from EleutherAI/add-normalizers
Add back optimizer-aware gradients
2 parents afd26bc + 833678c commit b48ce8d

File tree

12 files changed

+479
-121
lines changed

12 files changed

+479
-121
lines changed

bergson/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__version__ = "0.4.6"
22

33
from .collection import collect_gradients
4+
from .collector.gradient_collectors import GradientCollector
45
from .config import (
56
AttentionConfig,
67
DataConfig,
@@ -11,6 +12,7 @@
1112
)
1213
from .data import load_gradient_dataset, load_gradients
1314
from .gradients import GradientProcessor
15+
from .normalizer.fit_normalizers import fit_normalizers
1416
from .query.attributor import Attributor
1517
from .query.faiss_index import FaissConfig
1618
from .score.scorer import Scorer
@@ -20,10 +22,12 @@
2022
"collect_gradients",
2123
"load_gradients",
2224
"load_gradient_dataset",
25+
"fit_normalizers",
2326
"Attributor",
2427
"FaissConfig",
2528
"FiniteDiff",
2629
"GradientProcessor",
30+
"GradientCollector",
2731
"IndexConfig",
2832
"DataConfig",
2933
"AttentionConfig",

bergson/__main__.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,10 @@ def execute(self):
107107
self.command.execute()
108108

109109

110-
def get_parser():
111-
"""Get the argument parser. Used for documentation generation."""
112-
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
113-
parser.add_arguments(Main, dest="prog")
114-
return parser
115-
116-
117110
def main(args: Optional[list[str]] = None):
118111
"""Parse CLI arguments and dispatch to the selected subcommand."""
119-
parser = get_parser()
112+
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
113+
parser.add_arguments(Main, dest="prog")
120114
prog: Main = parser.parse_args(args=args).prog
121115
prog.execute()
122116

bergson/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def build_worker(
5656
)
5757

5858
model, target_modules = setup_model_and_peft(cfg, rank)
59-
processor = create_processor(cfg, rank)
59+
processor = create_processor(model, ds, cfg, rank, target_modules)
6060

6161
attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}
6262

bergson/collector/collector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,10 +477,10 @@ def run_with_collector_hooks(
477477
total_processed = torch.tensor(0, device=self.model.device)
478478
prof = self._setup_profiler()
479479
step = 0
480-
481480
with prof:
482481
for indices in tqdm(
483-
self.batches, disable=self.rank != 0, desc=f"Computing {desc}"
482+
self.batches,
483+
desc=f"Computing {desc}",
484484
):
485485
batch = self.data[indices]
486486

@@ -503,6 +503,7 @@ def run_with_collector_hooks(
503503
step += 1
504504

505505
self.collector.process_batch(indices, losses=losses)
506+
total_processed += len(indices)
506507

507508
self.collector.teardown()
508509
if dist.is_initialized():

bergson/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class IndexConfig:
9595
processor_path: str = ""
9696
"""Path to a precomputed processor."""
9797

98+
normalizer: Literal["adafactor", "adam", "none"] = "none"
99+
"""Type of normalizer to use for the gradients."""
100+
98101
skip_preconditioners: bool = False
99102
"""Whether to skip computing preconditioners for the gradients."""
100103

bergson/gradients.py

Lines changed: 102 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -58,108 +58,6 @@ def state_dict(self) -> dict[str, str | Tensor]:
5858
}
5959

6060

61-
@dataclass
62-
class AdafactorNormalizer(Normalizer):
63-
"""
64-
Row and column sums of second moments of gradients for a matrix-valued parameter.
65-
"""
66-
67-
row: Tensor # shape [O]
68-
col: Tensor # shape [I]
69-
70-
def __post_init__(self):
71-
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
72-
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
73-
74-
@torch.compile
75-
def normalize_(
76-
self,
77-
grad: Tensor,
78-
eps: float = 1e-30,
79-
) -> Tensor:
80-
"""
81-
Normalize the row and column sums by adding a small epsilon.
82-
83-
Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They
84-
recommend 1e-30, but we use 1e-16 for extra numerical stability.
85-
"""
86-
# We follow the Adafactor implementation in the tensor2tensor repo, which is
87-
# different from the paper and from the PyTorch implementation. First add eps
88-
# to ensure these second moments are sufficiently far from zero. Then we don't
89-
# need to worry about numerical stability anywhere else, and we don't need to
90-
# materialize the outer product at any point.
91-
r, c = self.row.add(eps), self.col.add(eps)
92-
93-
# This is the denominator for V, the rank-one matrix of second moment estimates:
94-
# V = torch.outer(r, c) / denom
95-
# V_ij = r_i * c_j / denom
96-
# But we want to (implicitly) take the Hadamard product with the elementwise
97-
# reciprocal square root of V:
98-
# (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
99-
denom = r.mean()
100-
101-
# Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
102-
# by diag(a) and right-multiplying by diag(b). In this case we can represent
103-
# the elementwise reciprocal square root of V as ab^T where:
104-
# a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
105-
a = denom.sqrt() * r.rsqrt_() # shape [O]
106-
b = c.rsqrt_()
107-
108-
# Implicitly do the Hadamard product
109-
grad *= a[:, None] # [N, O] * [O] → [N, O]
110-
grad *= b[None, :]
111-
return grad
112-
113-
def to_adam(self) -> "AdamNormalizer":
114-
"""
115-
Convert this Adafactor normalizer to an Adam normalizer by materializing the
116-
rank-one second moment matrix.
117-
"""
118-
# Compute the second moment matrix as a square matrix of shape [O, I]
119-
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
120-
# add it outside the square root. This could cause infs though if there are
121-
# any exactly zero rows or columns, so we should be careful.
122-
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
123-
return AdamNormalizer(avg_sq=avg_sq)
124-
125-
126-
@dataclass
127-
class AdamNormalizer(Normalizer):
128-
"""
129-
Contains the second moments of the gradients.
130-
"""
131-
132-
avg_sq: Tensor
133-
134-
@torch.compile
135-
def normalize_(
136-
self,
137-
grad: Tensor,
138-
eps: float = 1e-8,
139-
) -> Tensor:
140-
"""Normalize the gradients by the square root of the second moments."""
141-
# Adam-style epsilon is added outside the square root
142-
denom = self.avg_sq.sqrt()
143-
return grad.div_(denom.add_(eps))
144-
145-
def to_adafactor(self) -> AdafactorNormalizer:
146-
"""
147-
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
148-
I-divergence (generalized Kullback-Leibler divergence) between the original
149-
and the factored second moments.
150-
"""
151-
# We assume avg_sq is a square matrix of shape [O, I]
152-
assert (
153-
self.avg_sq.ndim == 2
154-
), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
155-
156-
# Compute row and column means
157-
return AdafactorNormalizer(
158-
row=self.avg_sq.mean(dim=1), # shape [O]
159-
col=self.avg_sq.mean(dim=0), # shape [I]
160-
)
161-
162-
16361
@dataclass
16462
class GradientProcessor:
16563
"""Configuration for processing and compressing gradients."""
@@ -317,3 +215,105 @@ def out_attr(layer: nn.Module) -> str:
317215
return "out_channels"
318216
case _:
319217
raise ValueError(f"Unsupported layer type: {type(layer)}")
218+
219+
220+
@dataclass
221+
class AdafactorNormalizer(Normalizer):
222+
"""
223+
Row and column sums of second moments of gradients for a matrix-valued parameter.
224+
"""
225+
226+
row: Tensor # shape [O]
227+
col: Tensor # shape [I]
228+
229+
def __post_init__(self):
230+
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
231+
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
232+
233+
@torch.compile
234+
def normalize_(
235+
self,
236+
grad: Tensor,
237+
eps: float = 1e-30,
238+
) -> Tensor:
239+
"""
240+
Normalize `grad` in place using the factored (row/column) second moment estimates.
241+
242+
Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They
243+
recommend 1e-30, which is also the default value of `eps` here.
244+
"""
245+
# We follow the Adafactor implementation in the tensor2tensor repo, which is
246+
# different from the paper and from the PyTorch implementation. First add eps
247+
# to ensure these second moments are sufficiently far from zero. Then we don't
248+
# need to worry about numerical stability anywhere else, and we don't need to
249+
# materialize the outer product at any point.
250+
r, c = self.row.add(eps), self.col.add(eps)
251+
252+
# This is the denominator for V, the rank-one matrix of second moment estimates:
253+
# V = torch.outer(r, c) / denom
254+
# V_ij = r_i * c_j / denom
255+
# But we want to (implicitly) take the Hadamard product with the elementwise
256+
# reciprocal square root of V:
257+
# (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
258+
denom = r.mean()
259+
260+
# Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
261+
# by diag(a) and right-multiplying by diag(b). In this case we can represent
262+
# the elementwise reciprocal square root of V as ab^T where:
263+
# a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
264+
a = denom.sqrt() * r.rsqrt_() # shape [O]
265+
b = c.rsqrt_()
266+
267+
# Implicitly do the Hadamard product
268+
grad *= a[:, None] # [N, O] * [O] → [N, O]
269+
grad *= b[None, :]
270+
return grad
271+
272+
def to_adam(self) -> "AdamNormalizer":
273+
"""
274+
Convert this Adafactor normalizer to an Adam normalizer by materializing the
275+
rank-one second moment matrix.
276+
"""
277+
# Materialize the rank-one second moment matrix of shape [O, I]
278+
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
279+
# add it outside the square root. This could cause infs though if there are
280+
# any exactly zero rows or columns, so we should be careful.
281+
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
282+
return AdamNormalizer(avg_sq=avg_sq)
283+
284+
285+
@dataclass
286+
class AdamNormalizer(Normalizer):
287+
"""
288+
Contains the second moments of the gradients.
289+
"""
290+
291+
avg_sq: Tensor
292+
293+
@torch.compile
294+
def normalize_(
295+
self,
296+
grad: Tensor,
297+
eps: float = 1e-8,
298+
) -> Tensor:
299+
"""Normalize the gradients by the square root of the second moments."""
300+
# Adam-style epsilon is added outside the square root
301+
denom = self.avg_sq.sqrt()
302+
return grad.div_(denom.add_(eps))
303+
304+
def to_adafactor(self) -> AdafactorNormalizer:
305+
"""
306+
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
307+
I-divergence (generalized Kullback-Leibler divergence) between the original
308+
and the factored second moments.
309+
"""
310+
# We assume avg_sq is a 2D matrix of shape [O, I]
311+
assert (
312+
self.avg_sq.ndim == 2
313+
), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
314+
315+
# Compute row and column means
316+
return AdafactorNormalizer(
317+
row=self.avg_sq.mean(dim=1), # shape [O]
318+
col=self.avg_sq.mean(dim=0), # shape [I]
319+
)

bergson/normalizer/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)