@@ -58,108 +58,6 @@ def state_dict(self) -> dict[str, str | Tensor]:
5858 }
5959
6060
61- @dataclass
62- class AdafactorNormalizer (Normalizer ):
63- """
64- Row and column sums of second moments of gradients for a matrix-valued parameter.
65- """
66-
67- row : Tensor # shape [O]
68- col : Tensor # shape [I]
69-
70- def __post_init__ (self ):
71- assert self .row .ndim == 1 , f"Expected 1D tensor for row, got { self .row .ndim } D"
72- assert self .col .ndim == 1 , f"Expected 1D tensor for col, got { self .col .ndim } D"
73-
74- @torch .compile
75- def normalize_ (
76- self ,
77- grad : Tensor ,
78- eps : float = 1e-30 ,
79- ) -> Tensor :
80- """
81- Normalize the row and column sums by adding a small epsilon.
82-
83- Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They
84- recommend 1e-30, but we use 1e-16 for extra numerical stability.
85- """
86- # We follow the Adafactor implementation in the tensor2tensor repo, which is
87- # different from the paper and from the PyTorch implementation. First add eps
88- # to ensure these second moments are sufficiently far from zero. Then we don't
89- # need to worry about numerical stability anywhere else, and we don't need to
90- # materialize the outer product at any point.
91- r , c = self .row .add (eps ), self .col .add (eps )
92-
93- # This is the denominator for V, the rank-one matrix of second moment estimates:
94- # V = torch.outer(r, c) / denom
95- # V_ij = r_i * c_j / denom
96- # But we want to (implicitly) take the Hadamard product with the elementwise
97- # reciprocal square root of V:
98- # (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
99- denom = r .mean ()
100-
101- # Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
102- # by diag(a) and right-multiplying by diag(b). In this case we can represent
103- # the elementwise reciprocal square root of V as ab^T where:
104- # a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
105- a = denom .sqrt () * r .rsqrt_ () # shape [O]
106- b = c .rsqrt_ ()
107-
108- # Implicitly do the Hadamard product
109- grad *= a [:, None ] # [N, O] * [O] → [N, O]
110- grad *= b [None , :]
111- return grad
112-
113- def to_adam (self ) -> "AdamNormalizer" :
114- """
115- Convert this Adafactor normalizer to an Adam normalizer by materializing the
116- rank-one second moment matrix.
117- """
118- # Compute the second moment matrix as a square matrix of shape [O, I]
119- # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
120- # add it outside the square root. This could cause infs though if there are
121- # any exactly zero rows or columns, so we should be careful.
122- avg_sq = torch .outer (self .row , self .col ) / self .row .mean ()
123- return AdamNormalizer (avg_sq = avg_sq )
124-
125-
126- @dataclass
127- class AdamNormalizer (Normalizer ):
128- """
129- Contains the second moments of the gradients.
130- """
131-
132- avg_sq : Tensor
133-
134- @torch .compile
135- def normalize_ (
136- self ,
137- grad : Tensor ,
138- eps : float = 1e-8 ,
139- ) -> Tensor :
140- """Normalize the gradients by the square root of the second moments."""
141- # Adam-style epsilon is added outside the square root
142- denom = self .avg_sq .sqrt ()
143- return grad .div_ (denom .add_ (eps ))
144-
145- def to_adafactor (self ) -> AdafactorNormalizer :
146- """
147- Convert this Adam normalizer to an Adafactor normalizer, minimizing the
148- I-divergence (generalized Kullback-Leibler divergence) between the original
149- and the factored second moments.
150- """
151- # We assume avg_sq is a square matrix of shape [O, I]
152- assert (
153- self .avg_sq .ndim == 2
154- ), f"Expected 2D tensor for avg_sq, got { self .avg_sq .ndim } D"
155-
156- # Compute row and column means
157- return AdafactorNormalizer (
158- row = self .avg_sq .mean (dim = 1 ), # shape [O]
159- col = self .avg_sq .mean (dim = 0 ), # shape [I]
160- )
161-
162-
16361@dataclass
16462class GradientProcessor :
16563 """Configuration for processing and compressing gradients."""
@@ -317,3 +215,105 @@ def out_attr(layer: nn.Module) -> str:
317215 return "out_channels"
318216 case _:
319217 raise ValueError (f"Unsupported layer type: { type (layer )} " )
218+
219+
@dataclass
class AdafactorNormalizer(Normalizer):
    """
    Row and column sums of second moments of gradients for a matrix-valued parameter.
    """

    row: Tensor  # shape [O]: row sums of the second-moment matrix
    col: Tensor  # shape [I]: column sums of the second-moment matrix

    def __post_init__(self):
        assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
        assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"

    @torch.compile
    def normalize_(
        self,
        grad: Tensor,
        eps: float = 1e-30,
    ) -> Tensor:
        """
        Normalize `grad` in place by the factored second-moment estimate, after
        adding a small epsilon to the row and column sums.

        Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper,
        which recommends 1e-30; that is also our default here (callers may pass a
        larger value for extra numerical stability).
        """
        # We follow the Adafactor implementation in the tensor2tensor repo, which is
        # different from the paper and from the PyTorch implementation. First add eps
        # to ensure these second moments are sufficiently far from zero. Then we don't
        # need to worry about numerical stability anywhere else, and we don't need to
        # materialize the outer product at any point.
        r, c = self.row.add(eps), self.col.add(eps)

        # This is the denominator for V, the rank-one matrix of second moment estimates:
        #   V = torch.outer(r, c) / denom
        #   V_ij = r_i * c_j / denom
        # But we want to (implicitly) take the Hadamard product with the elementwise
        # reciprocal square root of V:
        #   (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
        denom = r.mean()

        # Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
        # by diag(a) and right-multiplying by diag(b). In this case we can represent
        # the elementwise reciprocal square root of V as ab^T where:
        #   a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
        # The in-place rsqrt_() calls are safe: r and c are fresh tensors created by
        # .add(eps) above, and denom was already materialized from r.
        a = denom.sqrt() * r.rsqrt_()  # shape [O]
        b = c.rsqrt_()  # shape [I]

        # Implicitly do the Hadamard product by broadcasting over rows, then columns.
        grad *= a[:, None]  # [O, I] * [O, 1] → [O, I]
        grad *= b[None, :]  # [O, I] * [1, I] → [O, I]
        return grad

    def to_adam(self) -> "AdamNormalizer":
        """
        Convert this Adafactor normalizer to an Adam normalizer by materializing the
        rank-one second moment matrix.
        """
        # Compute the second moment matrix, of shape [O, I]
        # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
        # add it outside the square root. This could cause infs though if there are
        # any exactly zero rows or columns, so we should be careful.
        avg_sq = torch.outer(self.row, self.col) / self.row.mean()
        return AdamNormalizer(avg_sq=avg_sq)
284+
@dataclass
class AdamNormalizer(Normalizer):
    """
    Contains the second moments of the gradients.
    """

    avg_sq: Tensor  # elementwise second-moment estimates; same shape as the gradient

    @torch.compile
    def normalize_(
        self,
        grad: Tensor,
        eps: float = 1e-8,
    ) -> Tensor:
        """Normalize `grad` in place by the square root of the second moments."""
        # Adam-style epsilon is added outside the square root. sqrt() allocates a
        # fresh tensor, so the in-place add_() below does not mutate self.avg_sq.
        denom = self.avg_sq.sqrt()
        return grad.div_(denom.add_(eps))

    def to_adafactor(self) -> AdafactorNormalizer:
        """
        Convert this Adam normalizer to an Adafactor normalizer, minimizing the
        I-divergence (generalized Kullback-Leibler divergence) between the original
        and the factored second moments.
        """
        # We assume avg_sq is a 2D matrix of shape [O, I] (not necessarily square)
        assert (
            self.avg_sq.ndim == 2
        ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"

        # The I-divergence-optimal rank-one factorization uses row and column means
        return AdafactorNormalizer(
            row=self.avg_sq.mean(dim=1),  # shape [O]
            col=self.avg_sq.mean(dim=0),  # shape [I]
        )
0 commit comments