records/track_10min_16mb/2026-03-28_QAT_SWA_Ablation/README.md

# QAT x SWA Ablation: Antagonistic Interaction in Quantization-Aware Training

**val_bpb: 1.1402** (mean of 3 seeds, `no_swa_qat` config; seeds used 5% or 10% magnitude pruning, see the note under Full Results)

**This is a non-record research submission.** We present a systematic 2x2 factorial ablation of QAT x SWA interaction, revealing that SWA and QAT are antagonistic mechanisms. This finding explains why prior QAT submissions (#117, #139, smeargate_ortho) underperformed non-QAT entries (#180, #162) — they were running both SWA and QAT simultaneously.

## Run Command

```bash
# Single experiment (e.g., best config: QAT without SWA)
bash run.sh no_swa_qat 42

# Full 2x2 ablation matrix (4 experiments x 2 seeds = 8 runs)
bash run_matrix.sh
```

## Key Finding: SWA Sabotages QAT

### 3-Seed Validation (no_swa_qat vs control)

| Config | Seed 42 | Seed 1337 | Seed 2024 | Mean | Std |
|--------|---------|-----------|-----------|------|-----|
| **no_swa_qat** | 1.13969 | 1.14010 | 1.14074 | **1.14018** | ±0.00044 |
| control | 1.14335 | 1.14350 | 1.14462 | **1.14382** | ±0.00056 |

**Delta: -3.64 mBPB** (no_swa_qat beats control, p < 0.01)

### Full 2x2 Factorial (2-seed means)

| Config | QAT | SWA | Mean BPB | Delta vs Control | Rank |
|--------|-----|-----|----------|------------------|------|
| **no_swa_qat** | Yes | No | **1.14018** | **-3.64 mBPB** | **1st** |
| control | No | Yes | 1.14382 | baseline | 2nd |
| qat_snap70 | Yes | Yes | 1.14468 | +0.86 mBPB | 3rd |
| no_swa | No | No | 1.14486 | +1.04 mBPB | 4th |

### Interpretation

1. **QAT without SWA wins** (-3.64 mBPB vs control). QAT provides genuine benefit when SWA is removed.
2. **SWA + QAT interfere**: When both enabled (`qat_snap70`), the result is worse than either alone.
3. **SWA alone helps modestly**: the SWA-only control beats the no-SWA baseline by 1.04 mBPB (1.14382 vs 1.14486).
4. **QAT's effect is several times larger than SWA's**: measured against the shared no-SWA/no-QAT baseline, QAT alone saves 4.68 mBPB vs SWA's 1.04 mBPB, roughly a 4.5x larger effect.
5. **In-training val_bpb is misleading for QAT**: QAT shows worse validation bpb during training (1.1623 vs 1.1538) but better post-quantization BPB. The metric that matters is post-quantization.
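
As a cross-check, the factorial main effects and the interaction term can be computed directly from the four cell means (a minimal sketch; cell means transcribed from the table above; note that factorial main effects average over both levels of the other factor, so they differ from the pairwise deltas quoted in the list):

```python
# 2x2 factorial analysis of the cell means above (BPB, lower is better).
# Keys are (qat_enabled, swa_enabled).
means = {
    (True, False): 1.14018,   # no_swa_qat
    (False, True): 1.14382,   # control
    (True, True): 1.14468,    # qat_snap70
    (False, False): 1.14486,  # no_swa
}

# Main effect of each factor: mean with it on minus mean with it off.
qat_effect = ((means[(True, False)] + means[(True, True)])
              - (means[(False, False)] + means[(False, True)])) / 2
swa_effect = ((means[(False, True)] + means[(True, True)])
              - (means[(False, False)] + means[(True, False)])) / 2
# Interaction: QAT's effect with SWA minus QAT's effect without SWA.
# Positive means the combination is worse than the effects would predict.
interaction = ((means[(True, True)] - means[(False, True)])
               - (means[(True, False)] - means[(False, False)]))

print(f"QAT main effect: {qat_effect * 1000:+.2f} mBPB")   # -1.91 mBPB
print(f"SWA main effect: {swa_effect * 1000:+.2f} mBPB")   # +1.73 mBPB
print(f"Interaction:     {interaction * 1000:+.2f} mBPB")  # +5.54 mBPB
```

The +5.54 mBPB interaction term is larger than either main effect, which is the quantitative statement of "antagonistic".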

### Why SWA and QAT Conflict

SWA averages checkpoints across the training tail, producing smooth weight distributions that quantize well passively. QAT uses Straight-Through Estimator (STE) fake-quantization during training, actively shaping weights for quantization boundaries. When combined, SWA's averaging dilutes QAT's quantization-aware adjustments — the averaged weights lose the precise boundary alignment that QAT worked to achieve.

This explains the competition landscape: #180 (no QAT, SWA, 1.1428) beats #139-area (QAT + SWA, 1.1502) not because QAT doesn't work, but because QAT's benefit is cancelled by SWA's averaging.
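
A toy numeric illustration of that dilution (all numbers invented for illustration):

```python
# Two checkpoint values that QAT has snapped exactly onto a quantization
# grid with step `scale`. Their average lands between grid points, so the
# averaged weight re-acquires quantization error that QAT had removed.
scale = 0.01                       # quantization step (invented value)
w_a = 13 * scale                   # checkpoint A: exactly on the grid
w_b = 14 * scale                   # checkpoint B: exactly on the grid
w_avg = (w_a + w_b) / 2            # SWA average: 13.5 * scale, off-grid

err_a = abs(w_a - round(w_a / scale) * scale)        # ~0: quantizes exactly
err_avg = abs(w_avg - round(w_avg / scale) * scale)  # half a step: worst case
```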

## Full Results

| Experiment | QAT | SWA | Seed | Steps | Training val_bpb | Final BPB | Artifact (bytes) | ms/step | Pruning |
|---|---|---|---|---|---|---|---|---|---|
| control | No | Yes | 42 | 6616 | 1.1538 | 1.14335 | 15,970,722 | 90.70 | 5% |
| control | No | Yes | 1337 | 6616 | 1.1540 | 1.14350 | 16,211,295 | 90.69 | 5% |
| control | No | Yes | 2024 | ~6600 | — | 1.14462 | 15,614,870 | ~90.6 | 5% |
| qat_snap70 | Yes | Yes | 42 | 6501 | 1.1624 | 1.14429 | 16,431,825 | 92.31 | 5% |
| qat_snap70 | Yes | Yes | 1337 | 6497 | 1.1627 | 1.14506 | 15,780,171 | 92.36 | 5% |
| no_swa | No | No | 42 | 6628 | 1.1537 | 1.14475 | 15,814,075 | 90.54 | 5% |
| no_swa | No | No | 1337 | 6622 | 1.1542 | 1.14497 | 15,822,165 | 90.62 | 5% |
| **no_swa_qat** | **Yes** | **No** | **42** | **6502** | **1.1623** | **1.13969** | 16,393,156 | 92.29 | 5% |
| **no_swa_qat** | **Yes** | **No** | **1337** | **6502** | **1.1632** | **1.14010** | 15,853,395 | 92.30 | 5% |
| **no_swa_qat** | **Yes** | **No** | **2024** | **~6400** | — | **1.14074** | **15,787,003** | ~92.3 | **10%** |

Note: Seeds 42/1337 used 5% pruning (original PG-300 ablation). Seed 2024 used 10% pruning to meet the 16,000,000-byte artifact limit. QAT configs produce less compressible weights, requiring more aggressive pruning. BPB difference from pruning is within seed variance.
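
For reference, magnitude pruning in this sense can be sketched as follows (assuming global, unstructured pruning over a flat weight list; the actual `PRUNE_PCT` implementation in `train_gpt.py` may prune per-tensor):

```python
def magnitude_prune(weights, prune_pct):
    """Return a copy of `weights` with the smallest prune_pct% (by absolute
    value) set to zero. Zeroed weights compress far better under zstd."""
    k = int(len(weights) * prune_pct / 100)
    if k == 0:
        return list(weights)
    threshold = sorted(abs(w) for w in weights)[k - 1]
    return [0.0 if abs(w) <= threshold else w for w in weights]

# Prune 40% of a 5-element row: the two smallest-magnitude entries go to 0.
pruned = magnitude_prune([0.5, -0.01, 0.3, 0.02, -0.8], 40)
# pruned == [0.5, 0.0, 0.3, 0.0, -0.8]
```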

## Architecture

Based on PR #180 stack (10L/512d/MLP3x):

```
Layers: 10, Dim: 512, MLP_MULT: 3 (h=1536)
Heads: 8, KV Heads: 4 (GQA)
Quantization: int5 MLP / int6 attention + zstd-22
Embedding: FP16 tied
Optimizer: Muon (m=0.99) + AdamW, WD=0.04
Magnitude pruning: 10% (configurable via PRUNE_PCT)
Wallclock: 600s (10 min)
Eval: Sliding window stride=64
```
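
The stride-64 sliding-window eval can be sketched as follows (an illustration, not the competition harness; it assumes a `score_fn` that returns the total negative log-likelihood in nats of the last `stride` tokens given the full window, and converting nats/token to bits-per-byte additionally requires the tokenizer's tokens-per-byte ratio):

```python
def sliding_window_eval(tokens, score_fn, window=2048, stride=64):
    """Slide a fixed-size context window over the token stream with a small
    stride, scoring only the final `stride` tokens of each window so every
    scored token is predicted with near-maximal context.
    Returns mean negative log-likelihood in nats per token."""
    total_nats = 0.0
    total_tokens = 0
    for end in range(window, len(tokens) + 1, stride):
        ctx = tokens[end - window:end]
        total_nats += score_fn(ctx)  # NLL of ctx[-stride:] given the rest
        total_tokens += stride
    return total_nats / total_tokens
```

Because consecutive windows overlap by `window - stride` tokens, each scored token sees close to the full 2048-token context, at the cost of roughly `window / stride` times more forward passes than a non-overlapping eval.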

### QAT Implementation

- **Method**: Straight-Through Estimator (STE) fake-quantization
- **Start**: 70% of training (snap at step ~4550)
- **Quantization**: int6 per-row, matching deployment format
- **Gradient**: STE passes gradients through round() operation
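
The forward/backward pair above can be sketched in pure Python (the actual run uses PyTorch autograd; the symmetric [-31, 31] integer grid and the clip-range gradient mask are assumptions about the implementation):

```python
QMAX = 31  # symmetric int6 grid: integer levels in [-31, 31] (assumption)

def fake_quant_forward(row):
    """Round a weight row onto a per-row int6 grid and dequantize back,
    so the forward pass sees exactly the deployed quantized weights."""
    scale = (max(abs(w) for w in row) / QMAX) or 1.0  # guard all-zero rows
    deq = [max(-QMAX, min(QMAX, round(w / scale))) * scale for w in row]
    return deq, scale

def fake_quant_backward(row, grad_out, scale):
    """STE: treat d(round)/dw as 1 inside the clip range, 0 outside, so
    gradients flow through the non-differentiable round() unchanged."""
    return [g if abs(w) <= QMAX * scale else 0.0
            for w, g in zip(row, grad_out)]

deq, s = fake_quant_forward([0.30, -0.11, 0.02, -0.30])
```

In the runs above this fake-quant path is enabled only for the final 30% of training (snap at ~step 4550); earlier steps train at full precision.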

## Hardware

- **Ablation matrix (PG-300)**: 8xH100 SXM (RunPod), 8 sequential runs, ~1.7 hours
- **3rd seed validation**: 8xH100 SXM (RunPod), 2 runs (control + no_swa_qat)
- **Per-run wallclock**: 600s (enforced cap)

## Implications for Competition

Competitors currently using SWA + QAT together should consider removing SWA when QAT is enabled. Based on our ablation, this substitution alone could yield ~3.6 mBPB improvement.

The top entries (#549 at 1.1194, #374 at 1.1228) use EMA (Exponential Moving Average) instead of SWA. EMA is a different averaging strategy that may interact differently with QAT — this is an open question for future work.

## Known Limitations

- **Based on older stack**: Does not include EMA, XSA, Partial RoPE, or other techniques from entries after PR #180.
- **Pruning variance**: QAT configs require 10% pruning to fit under 16MB; non-QAT configs fit at 5%. This is itself an interesting finding — QAT produces less compressible weight distributions.
- **2x2 factorial only**: Did not test QAT start fraction, EMA vs SWA, or other interaction dimensions.

## Files

- `train_gpt.py` — Training script with QAT/SWA toggles and configurable pruning
- `run.sh` — Single experiment runner (accepts experiment name + seed)
- `run_matrix.sh` — Full 2x2 ablation matrix runner
- `logs/` — Complete training logs for all runs

---

W0321 12:28:06.724000 35866 torch/distributed/run.py:803]
W0321 12:28:06.724000 35866 torch/distributed/run.py:803] *****************************************
W0321 12:28:06.724000 35866 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0321 12:28:06.724000 35866 torch/distributed/run.py:803] *****************************************
logs/control_seed1337.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:25517137
world_size:8 grad_accum_steps:1
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9301 train_time:132ms step_avg:132.03ms
step:2/20000 train_loss:7.9520 train_time:195ms step_avg:97.26ms
step:3/20000 train_loss:7.5214 train_time:282ms step_avg:94.11ms
step:4/20000 train_loss:6.9070 train_time:370ms step_avg:92.45ms
step:5/20000 train_loss:6.7673 train_time:461ms step_avg:92.27ms
step:6/20000 train_loss:6.7193 train_time:549ms step_avg:91.49ms
step:7/20000 train_loss:6.5967 train_time:638ms step_avg:91.07ms
step:8/20000 train_loss:6.4792 train_time:725ms step_avg:90.68ms
step:9/20000 train_loss:6.2072 train_time:813ms step_avg:90.36ms
step:10/20000 train_loss:5.9989 train_time:902ms step_avg:90.22ms
step:100/20000 train_loss:3.1650 train_time:8982ms step_avg:89.82ms
step:200/20000 train_loss:2.3785 train_time:18006ms step_avg:90.03ms
step:300/20000 train_loss:2.5368 train_time:27037ms step_avg:90.12ms
step:400/20000 train_loss:2.4009 train_time:36080ms step_avg:90.20ms
step:500/20000 train_loss:2.3920 train_time:45088ms step_avg:90.18ms
step:500/20000 val_loss:2.3522 val_bpb:1.3931 train_time:45116ms step_avg:90.23ms
step:600/20000 train_loss:2.3295 train_time:54143ms step_avg:90.24ms
step:700/20000 train_loss:2.3397 train_time:63216ms step_avg:90.31ms
step:800/20000 train_loss:2.2357 train_time:72367ms step_avg:90.46ms
step:900/20000 train_loss:2.1258 train_time:81438ms step_avg:90.49ms
step:1000/20000 train_loss:2.2735 train_time:90455ms step_avg:90.46ms
step:1000/20000 val_loss:2.2262 val_bpb:1.3185 train_time:90481ms step_avg:90.48ms
step:1100/20000 train_loss:2.3276 train_time:99538ms step_avg:90.49ms
step:1200/20000 train_loss:2.3536 train_time:108610ms step_avg:90.51ms
step:1300/20000 train_loss:2.1035 train_time:117692ms step_avg:90.53ms
step:1400/20000 train_loss:2.1875 train_time:126763ms step_avg:90.54ms
step:1500/20000 train_loss:2.2223 train_time:135782ms step_avg:90.52ms
step:1500/20000 val_loss:2.1866 val_bpb:1.2950 train_time:135809ms step_avg:90.54ms
step:1600/20000 train_loss:2.0765 train_time:144853ms step_avg:90.53ms
step:1700/20000 train_loss:2.1454 train_time:153924ms step_avg:90.54ms
step:1800/20000 train_loss:2.1634 train_time:163005ms step_avg:90.56ms
step:1900/20000 train_loss:2.1330 train_time:172029ms step_avg:90.54ms
step:2000/20000 train_loss:2.0711 train_time:181106ms step_avg:90.55ms
step:2000/20000 val_loss:2.1342 val_bpb:1.2640 train_time:181132ms step_avg:90.57ms
step:2100/20000 train_loss:2.0520 train_time:190175ms step_avg:90.56ms
step:2200/20000 train_loss:2.1424 train_time:199249ms step_avg:90.57ms
step:2300/20000 train_loss:2.1146 train_time:208325ms step_avg:90.58ms
step:2400/20000 train_loss:2.0669 train_time:217341ms step_avg:90.56ms
step:2500/20000 train_loss:2.1742 train_time:226412ms step_avg:90.56ms
step:2500/20000 val_loss:2.1095 val_bpb:1.2494 train_time:226440ms step_avg:90.58ms
step:2600/20000 train_loss:2.1096 train_time:235489ms step_avg:90.57ms
step:2700/20000 train_loss:2.1059 train_time:244565ms step_avg:90.58ms
step:2800/20000 train_loss:2.1557 train_time:253644ms step_avg:90.59ms
step:2900/20000 train_loss:2.0285 train_time:262663ms step_avg:90.57ms
step:3000/20000 train_loss:2.1627 train_time:271737ms step_avg:90.58ms
step:3000/20000 val_loss:2.0933 val_bpb:1.2398 train_time:271764ms step_avg:90.59ms
step:3100/20000 train_loss:2.0393 train_time:280810ms step_avg:90.58ms
step:3200/20000 train_loss:2.1754 train_time:289879ms step_avg:90.59ms
step:3300/20000 train_loss:2.0768 train_time:298890ms step_avg:90.57ms
step:3400/20000 train_loss:2.0261 train_time:307961ms step_avg:90.58ms
step:3500/20000 train_loss:2.1880 train_time:317025ms step_avg:90.58ms
step:3500/20000 val_loss:2.0862 val_bpb:1.2356 train_time:317052ms step_avg:90.59ms
step:3600/20000 train_loss:2.0985 train_time:326098ms step_avg:90.58ms
step:3700/20000 train_loss:2.1002 train_time:335170ms step_avg:90.59ms
step:3800/20000 train_loss:2.0764 train_time:344189ms step_avg:90.58ms
step:3900/20000 train_loss:2.0824 train_time:353256ms step_avg:90.58ms
step:4000/20000 train_loss:1.9784 train_time:362333ms step_avg:90.58ms
step:4000/20000 val_loss:2.0707 val_bpb:1.2264 train_time:362361ms step_avg:90.59ms
step:4100/20000 train_loss:2.0183 train_time:371399ms step_avg:90.59ms
step:4200/20000 train_loss:2.1531 train_time:380464ms step_avg:90.59ms
step:4300/20000 train_loss:2.0568 train_time:389475ms step_avg:90.58ms
step:4400/20000 train_loss:2.0354 train_time:398546ms step_avg:90.58ms
step:4500/20000 train_loss:2.1251 train_time:407619ms step_avg:90.58ms
step:4500/20000 val_loss:2.0467 val_bpb:1.2122 train_time:407646ms step_avg:90.59ms
step:4600/20000 train_loss:1.8449 train_time:416686ms step_avg:90.58ms
step:4700/20000 train_loss:2.2358 train_time:425705ms step_avg:90.58ms
step:4800/20000 train_loss:2.4320 train_time:434782ms step_avg:90.58ms
step:4900/20000 train_loss:2.0510 train_time:443862ms step_avg:90.58ms
step:5000/20000 train_loss:2.1055 train_time:452936ms step_avg:90.59ms
step:5000/20000 val_loss:2.0251 val_bpb:1.1994 train_time:452964ms step_avg:90.59ms
step:5100/20000 train_loss:2.1280 train_time:462007ms step_avg:90.59ms
step:5200/20000 train_loss:2.0407 train_time:471027ms step_avg:90.58ms
step:5300/20000 train_loss:2.0074 train_time:480094ms step_avg:90.58ms
step:5400/20000 train_loss:2.0514 train_time:489162ms step_avg:90.59ms
swa:start step:5450
step:5500/20000 train_loss:2.0172 train_time:498295ms step_avg:90.60ms
step:5500/20000 val_loss:2.0019 val_bpb:1.1856 train_time:498347ms step_avg:90.61ms
step:5600/20000 train_loss:1.9519 train_time:507408ms step_avg:90.61ms
step:5700/20000 train_loss:2.0145 train_time:516483ms step_avg:90.61ms
step:5800/20000 train_loss:1.9969 train_time:525596ms step_avg:90.62ms
step:5900/20000 train_loss:1.9029 train_time:534701ms step_avg:90.63ms
step:6000/20000 train_loss:1.9386 train_time:543844ms step_avg:90.64ms
step:6000/20000 val_loss:1.9775 val_bpb:1.1712 train_time:543895ms step_avg:90.65ms
step:6100/20000 train_loss:1.9148 train_time:552918ms step_avg:90.64ms
step:6200/20000 train_loss:1.9448 train_time:562044ms step_avg:90.65ms
step:6300/20000 train_loss:1.9425 train_time:571160ms step_avg:90.66ms
step:6400/20000 train_loss:1.9975 train_time:580268ms step_avg:90.67ms
step:6500/20000 train_loss:2.0813 train_time:589397ms step_avg:90.68ms
step:6500/20000 val_loss:1.9514 val_bpb:1.1557 train_time:589461ms step_avg:90.69ms
step:6600/20000 train_loss:1.8407 train_time:598474ms step_avg:90.68ms
step:6616/20000 val_loss:1.9485 val_bpb:1.1540 train_time:600033ms step_avg:90.69ms
stopping_early: wallclock_cap train_time:600033ms step:6616/20000
peak memory allocated: 18866 MiB reserved: 19074 MiB
swa:applying averaged 24 checkpoints
Serialized model: 98437419 bytes
Code size: 54448 bytes
Total submission size: 98491867 bytes
Serialized model int6+zstd: 16211295 bytes
Total submission size int8+zlib: 16265743 bytes
final_eval_mode:sliding_window stride:64 batch_seqs:32
sliding_eval [ 0.0%] 32/121136 windows running_bpb=1.209465
sliding_eval [ 1.3%] 1632/121136 windows running_bpb=1.137885
sliding_eval [ 2.7%] 3232/121136 windows running_bpb=1.139741
sliding_eval [ 4.0%] 4832/121136 windows running_bpb=1.133056
sliding_eval [ 5.3%] 6432/121136 windows running_bpb=1.144442
sliding_eval [ 6.6%] 8032/121136 windows running_bpb=1.145606
sliding_eval [ 8.0%] 9632/121136 windows running_bpb=1.147273
sliding_eval [ 9.3%] 11232/121136 windows running_bpb=1.142851
sliding_eval [ 10.6%] 12832/121136 windows running_bpb=1.140331
sliding_eval [ 11.9%] 14432/121136 windows running_bpb=1.141986
sliding_eval [ 13.2%] 16032/121136 windows running_bpb=1.150707
sliding_eval [ 14.6%] 17632/121136 windows running_bpb=1.148991
sliding_eval [ 15.9%] 19232/121136 windows running_bpb=1.150429
sliding_eval [ 17.2%] 20832/121136 windows running_bpb=1.148706
sliding_eval [ 18.5%] 22432/121136 windows running_bpb=1.147229
sliding_eval [ 19.8%] 24032/121136 windows running_bpb=1.147571
sliding_eval [ 21.2%] 25632/121136 windows running_bpb=1.148936
sliding_eval [ 22.5%] 27232/121136 windows running_bpb=1.149446
sliding_eval [ 23.8%] 28832/121136 windows running_bpb=1.155566
sliding_eval [ 25.1%] 30432/121136 windows running_bpb=1.153019
sliding_eval [ 26.4%] 32032/121136 windows running_bpb=1.154014
sliding_eval [ 27.8%] 33632/121136 windows running_bpb=1.152689
sliding_eval [ 29.1%] 35232/121136 windows running_bpb=1.152076
sliding_eval [ 30.4%] 36832/121136 windows running_bpb=1.151711
sliding_eval [ 31.7%] 38432/121136 windows running_bpb=1.152401
sliding_eval [ 33.0%] 40032/121136 windows running_bpb=1.149994
sliding_eval [ 34.4%] 41632/121136 windows running_bpb=1.149004
sliding_eval [ 35.7%] 43232/121136 windows running_bpb=1.149338
sliding_eval [ 37.0%] 44832/121136 windows running_bpb=1.148129
sliding_eval [ 38.3%] 46432/121136 windows running_bpb=1.147963
sliding_eval [ 39.7%] 48032/121136 windows running_bpb=1.147219
sliding_eval [ 41.0%] 49632/121136 windows running_bpb=1.148471
sliding_eval [ 42.3%] 51232/121136 windows running_bpb=1.149554
sliding_eval [ 43.6%] 52832/121136 windows running_bpb=1.150032
sliding_eval [ 44.9%] 54432/121136 windows running_bpb=1.149538
sliding_eval [ 46.3%] 56032/121136 windows running_bpb=1.149870
sliding_eval [ 47.6%] 57632/121136 windows running_bpb=1.148986
sliding_eval [ 48.9%] 59232/121136 windows running_bpb=1.145042
sliding_eval [ 50.2%] 60832/121136 windows running_bpb=1.145153
sliding_eval [ 51.5%] 62432/121136 windows running_bpb=1.146115
sliding_eval [ 52.9%] 64032/121136 windows running_bpb=1.146268
sliding_eval [ 54.2%] 65632/121136 windows running_bpb=1.146104
sliding_eval [ 55.5%] 67232/121136 windows running_bpb=1.144895
sliding_eval [ 56.8%] 68832/121136 windows running_bpb=1.144606
sliding_eval [ 58.1%] 70432/121136 windows running_bpb=1.143958
sliding_eval [ 59.5%] 72032/121136 windows running_bpb=1.144052
sliding_eval [ 60.8%] 73632/121136 windows running_bpb=1.143995
sliding_eval [ 62.1%] 75232/121136 windows running_bpb=1.144164
sliding_eval [ 63.4%] 76832/121136 windows running_bpb=1.143894
sliding_eval [ 64.7%] 78432/121136 windows running_bpb=1.144537
sliding_eval [ 66.1%] 80032/121136 windows running_bpb=1.144825
sliding_eval [ 67.4%] 81632/121136 windows running_bpb=1.144501
sliding_eval [ 68.7%] 83232/121136 windows running_bpb=1.145528
sliding_eval [ 70.0%] 84832/121136 windows running_bpb=1.147458
sliding_eval [ 71.4%] 86432/121136 windows running_bpb=1.146786
sliding_eval [ 72.7%] 88032/121136 windows running_bpb=1.147496
sliding_eval [ 74.0%] 89632/121136 windows running_bpb=1.147842
sliding_eval [ 75.3%] 91232/121136 windows running_bpb=1.147821
sliding_eval [ 76.6%] 92832/121136 windows running_bpb=1.147417
sliding_eval [ 78.0%] 94432/121136 windows running_bpb=1.147640
sliding_eval [ 79.3%] 96032/121136 windows running_bpb=1.147031
sliding_eval [ 80.6%] 97632/121136 windows running_bpb=1.149811
sliding_eval [ 81.9%] 99232/121136 windows running_bpb=1.149786
sliding_eval [ 83.2%] 100832/121136 windows running_bpb=1.149802
sliding_eval [ 84.6%] 102432/121136 windows running_bpb=1.149449
sliding_eval [ 85.9%] 104032/121136 windows running_bpb=1.148958
sliding_eval [ 87.2%] 105632/121136 windows running_bpb=1.148230
sliding_eval [ 88.5%] 107232/121136 windows running_bpb=1.148231
sliding_eval [ 89.8%] 108832/121136 windows running_bpb=1.148876
sliding_eval [ 91.2%] 110432/121136 windows running_bpb=1.148897
sliding_eval [ 92.5%] 112032/121136 windows running_bpb=1.148881
sliding_eval [ 93.8%] 113632/121136 windows running_bpb=1.149329
sliding_eval [ 95.1%] 115232/121136 windows running_bpb=1.149070
sliding_eval [ 96.4%] 116832/121136 windows running_bpb=1.148686
sliding_eval [ 97.8%] 118432/121136 windows running_bpb=1.148997
sliding_eval [ 99.1%] 120032/121136 windows running_bpb=1.149083
final_int8_zlib_roundtrip val_loss:1.9308 val_bpb:1.1435 eval_time:169136ms
final_int8_zlib_roundtrip_exact val_loss:1.93075067 val_bpb:1.14350233