Commit 2aed291
Final EKFAC implementation (#123)
* ekfac implementation done (untested)
* remove unnecessary squeeze
* add tkfac
* fix claude issues
* shampoo
* minor fix
* Add EKFAC tests and fix a couple of bugs (#125)
* Fix mask bug and add batch size invariance test with toy model
The backward_hook was using g.reshape(-1, O) which includes padding
positions in the covariance computation. This causes incorrect results
when batches have different sequence lengths.
Before this commit, the added test failed with:
> FAILED tests/ekfac_tests/test_batch_size_invariance.py::test_trace_batch_invariant[seq_lengths1-20] - AssertionError: Scalars are not close!
>
> Expected 1.231401894309304 but got 0.8983965093439276.
> Absolute difference: 0.33300538496537635 (up to 1e-4 allowed)
> Relative difference: 0.27042786478102654 (up to 0.01 allowed)
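A minimal sketch of the masking issue described above, written with NumPy rather than PyTorch so it stays self-contained. The helper `grad_covariance`, the mask layout, and the junk values placed in the padding slots are assumptions for illustration, not the code in the actual backward hook.

```python
import numpy as np


def grad_covariance(g, mask=None):
    """Average outer product of per-token output gradients.

    g:    (batch, seq, out_dim) gradients w.r.t. a layer's output.
    mask: (batch, seq) boolean, True for real tokens. If None, every
          position is used, which mirrors g.reshape(-1, O).
    """
    out_dim = g.shape[-1]
    if mask is None:
        rows = g.reshape(-1, out_dim)  # padding rows are included
    else:
        rows = g[mask]  # keep only real tokens, shape (n_valid, out_dim)
    return rows.T @ rows / rows.shape[0]


rng = np.random.default_rng(0)
out_dim = 4
seq_a = rng.normal(size=(5, out_dim))  # sequence of length 5
seq_b = rng.normal(size=(2, out_dim))  # sequence of length 2

# Reference: covariance over the 7 real tokens only.
reference = grad_covariance(np.concatenate([seq_a, seq_b])[None])

# The same tokens in one padded batch; padding slots hold arbitrary values.
padded = rng.normal(size=(2, 5, out_dim))
padded[0] = seq_a
padded[1, :2] = seq_b
mask = np.zeros((2, 5), dtype=bool)
mask[0, :] = True
mask[1, :2] = True

unmasked = grad_covariance(padded)      # depends on how the batch is padded
masked = grad_covariance(padded, mask)  # matches the reference
```

With the mask applied, the result no longer depends on how sequences of different lengths are grouped into batches, which is the property the batch size invariance test checks.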
* Fix use_dataset_labels condition and add FIM accuracy test
The condition `if not hessian_cfg.use_dataset_labels:` was inverted,
causing the empirical Fisher (with dataset labels) to use sampled
labels and vice versa.
Add test_fim_accuracy.py which verifies that KFAC approximates the
Fisher Information Matrix within tolerance for both empirical FIM
(dataset labels) and true FIM (sampled labels).
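The label selection that the fix restores can be sketched as follows. The function `choose_labels` and its arguments are hypothetical names, and NumPy stands in for the real model code; only the direction of the condition is taken from the commit message.

```python
import numpy as np


def choose_labels(logits, dataset_labels, use_dataset_labels, rng):
    """Pick the labels used to compute the gradients that form the Fisher.

    use_dataset_labels=True  -> empirical Fisher (labels from the dataset)
    use_dataset_labels=False -> true Fisher (labels sampled from the model)
    """
    if use_dataset_labels:
        return dataset_labels
    # Sample y ~ softmax(logits), one label per row.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])


rng = np.random.default_rng(0)
# Sharply peaked logits, so sampling returns the argmax almost surely.
logits = np.array([[50.0, 0.0, 0.0], [0.0, 50.0, 0.0]])
dataset_labels = np.array([2, 2])

empirical = choose_labels(logits, dataset_labels, True, rng)
sampled = choose_labels(logits, dataset_labels, False, rng)
```

The inverted condition swapped these two branches, so each Fisher variant was silently computed with the other variant's labels.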
* Add ground truth ekfac tests
This is still missing FSDP support and test_apply_ekfac.py from
#68
Co-Authored-By: LouisYRYJ <[email protected]>
* ekfac_tests/test_batch_size_invariance.py: Fix error thresholds when running on CPU
* Cleanup EKFAC tests
- Replace set_all_seeds by existing setup_reproducibility
- Reuse approximate_hessians instead of doing something
equivalent manually.
* Add --token_batch_size option to EKFAC tests
* Add --n_samples option to EKFAC tests
Allow configuring the number of samples from pile-10k dataset via
pytest command line option instead of hardcoding 100. The dataset
directory is now named dynamically (e.g., pile_100_examples).
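One way to wire such an option into a `conftest.py` is sketched below. `pytest_addoption` and `parser.addoption` are standard pytest hooks; the default of 100 and the helper `dataset_dir_name` are assumptions based on the description above, not the code in the test suite.

```python
def pytest_addoption(parser):
    """Register --n_samples as a pytest command line option."""
    parser.addoption(
        "--n_samples",
        action="store",
        type=int,
        default=100,  # assumed default, matching the previously hardcoded value
        help="Number of samples to take from the pile-10k dataset.",
    )


def dataset_dir_name(n_samples):
    """Build the dataset directory name, e.g. pile_100_examples."""
    return f"pile_{n_samples}_examples"
```

A fixture can then read the value with `request.config.getoption("--n_samples")` and pass it to `dataset_dir_name`.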
* hessians: Fix distributed support and test it
Restore the calls to dist.barrier that existed in
#13. Without them, the process would
hang when running with world_size > 1.
For testing, we add _allocate_batches_world to compute the batches for the
ground truth. The tests don't pass yet due to numerical errors; this is handled
in the next commit by changing our comparison logic.
* ekfac_tests: Use appropriate metrics for each comparison
- Eigenvectors: Check |cosine_similarity| ≈ 1 per column, which naturally
handles sign ambiguity (eigenvectors are only defined up to sign)
- Covariances: Check relative Frobenius norm since values should match exactly
- Eigenvalue corrections: Align signs based on eigenvector orientation, then
check relative error (λ[i,j] transforms as sign_G[i] * sign_A[j])
- Also reenable CPU tests which pass after this change.
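The three comparison metrics listed above can be sketched with NumPy as follows. All function names are hypothetical. The plain projection `Q_G.T @ M @ Q_A` is only a stand-in for a quantity that transforms as sign_G[i] * sign_A[j]; it is not the project's eigenvalue correction formula.

```python
import numpy as np


def abs_cosine_per_column(Q_ref, Q):
    """|cosine similarity| between matching columns; 1 means equal up to sign."""
    num = np.abs((Q_ref * Q).sum(axis=0))
    den = np.linalg.norm(Q_ref, axis=0) * np.linalg.norm(Q, axis=0)
    return num / den


def relative_frobenius_error(ref, actual):
    """Frobenius norm of the difference, relative to the reference."""
    return np.linalg.norm(actual - ref) / np.linalg.norm(ref)


def column_signs(Q_ref, Q):
    """+1 or -1 per column, depending on orientation relative to the reference."""
    return np.sign((Q_ref * Q).sum(axis=0))


def align_corrections(lam, sign_G, sign_A):
    """Undo the sign change lam[i, j] -> sign_G[i] * sign_A[j] * lam[i, j]."""
    return lam * np.outer(sign_G, sign_A)


rng = np.random.default_rng(0)


def random_spd(n):
    X = rng.normal(size=(n, n))
    return X @ X.T


_, Q_G_ref = np.linalg.eigh(random_spd(3))
_, Q_A_ref = np.linalg.eigh(random_spd(4))

# A second run may return the same eigenvectors with some signs flipped.
Q_G = Q_G_ref * np.array([1.0, -1.0, 1.0])
Q_A = Q_A_ref * np.array([-1.0, 1.0, 1.0, -1.0])

M = rng.normal(size=(3, 4))  # stand-in for a per-layer gradient
lam_ref = Q_G_ref.T @ M @ Q_A_ref
lam = Q_G.T @ M @ Q_A

sign_G = column_signs(Q_G_ref, Q_G)
sign_A = column_signs(Q_A_ref, Q_A)
lam_aligned = align_corrections(lam, sign_G, sign_A)
```

Comparing `lam` to `lam_ref` directly would report large errors wherever a sign flipped, even though both runs describe the same eigenbasis.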
* ekfac_tests: Relax thresholds for distributed runs
With world_size > 1, floating-point reduction order differs between ground
truth (single process) and distributed run, causing larger numerical
differences in some layers.
For eigenvectors, use average |cos_sim| instead of minimum - this tolerates
occasional outlier eigenvectors while maintaining a stricter threshold
(1e-3 vs 0.1 that would be needed for min).
For eigenvalue corrections, use atol=0.2 when world_size > 1.
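The difference between the mean and the minimum criterion can be illustrated with made-up similarity values. The helper `eigenvectors_match` and the numbers below are assumptions; only the thresholds 1e-3 and 0.1 come from the commit message.

```python
import numpy as np


def eigenvectors_match(abs_cos, tol, reduce="mean"):
    """Check that the reduced |cos_sim| is within tol of 1."""
    score = abs_cos.mean() if reduce == "mean" else abs_cos.min()
    return bool(1.0 - score <= tol)


# 999 well-matched eigenvectors and a single outlier.
abs_cos = np.full(1000, 1.0 - 1e-6)
abs_cos[0] = 0.92

mean_ok = eigenvectors_match(abs_cos, tol=1e-3, reduce="mean")
min_strict_ok = eigenvectors_match(abs_cos, tol=1e-3, reduce="min")
min_loose_ok = eigenvectors_match(abs_cos, tol=0.1, reduce="min")
```

Here the mean criterion passes at 1e-3, while the minimum criterion passes only once the tolerance is loosened to 0.1, which would also accept much worse matrices.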
* adjust test + normalize shampoo and tkfac
* minor fixes, correct tensor handling in shampoo and tkfac, introduce apply_hessian (WIP)
---------
Co-authored-by: Guillaume Martres <[email protected]>
1 parent adbc3d1 commit 2aed291
File tree
28 files changed (+3612, -63 lines changed)
- bergson
- collector
- hessians
- utils
- examples
- tests/ekfac_tests
- ground_truth