Commit 9a70981
Refactor ForwardKLEvaluator to compute IS accuracy and ESS metrics
- Replace forward KL with importance-weighted accuracy and effective sample size
- Shard by problem_id hash (not trace index) so each rank gets complete problems
- Add TraceTensors dataclass with smart constructors (empty, from_traces)
- Vectorize log prob computation using F.cross_entropy with completion mask
- Add _scatter_logsumexp for numerically stable grouped reductions
- Use allreduce_scalar for cleaner distributed reduction
- Pre-tensorize all trace data for efficient batch slicing
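
To illustrate the sharding bullet above, here is a minimal sketch of assigning traces to ranks by a hash of `problem_id`, so every trace for a given problem lands on the same rank. It is not the code from this commit: the helper names `shard_for_problem` and `traces_for_rank`, the dict-shaped traces, and the choice of SHA-256 are all assumptions made for illustration. A `hashlib` digest is used because Python's built-in `hash()` on strings is salted per process and would disagree across ranks.

```python
import hashlib


def shard_for_problem(problem_id: str, world_size: int) -> int:
    """Map a problem_id to a rank index (hypothetical helper).

    Uses a SHA-256 digest so the mapping is identical in every process.
    """
    digest = hashlib.sha256(problem_id.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % world_size


def traces_for_rank(traces, rank: int, world_size: int):
    """Keep only the traces whose problem hashes to this rank (hypothetical helper)."""
    return [
        t for t in traces
        if shard_for_problem(t["problem_id"], world_size) == rank
    ]


# Illustrative data only: three problems with a varying number of traces each.
traces = [
    {"problem_id": "p0", "trace": 0},
    {"problem_id": "p0", "trace": 1},
    {"problem_id": "p1", "trace": 0},
    {"problem_id": "p2", "trace": 0},
    {"problem_id": "p2", "trace": 1},
    {"problem_id": "p2", "trace": 2},
]
world_size = 2
shards = [traces_for_rank(traces, r, world_size) for r in range(world_size)]
```

Sharding by trace index can split one problem's traces across ranks, which would make any per-problem normalization incomplete on each rank. Hashing the problem key avoids that, at the cost of possibly uneven shard sizes.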
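
The metrics named in the bullets above can be sketched in plain Python. This is an illustration under stated assumptions, not the commit's implementation: it assumes self-normalized importance weights, the common effective sample size definition ESS = (Σw)² / Σw², and a per-trace 0/1 correctness label. The functions `logsumexp`, `grouped_logsumexp`, and `is_accuracy_and_ess` are hypothetical stdlib stand-ins. The commit itself mentions a tensor-based `_scatter_logsumexp`, whose actual signature is not shown here.

```python
import math


def logsumexp(xs):
    """Stable log(sum(exp(x))): subtract the max before exponentiating."""
    m = max(xs)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))


def grouped_logsumexp(values, group_ids):
    """Per-group logsumexp, a plain-Python analogue of a scatter reduction."""
    groups = {}
    for v, g in zip(values, group_ids):
        groups.setdefault(g, []).append(v)
    return {g: logsumexp(vs) for g, vs in groups.items()}


def is_accuracy_and_ess(log_weights, correct):
    """Self-normalized importance-weighted accuracy and ESS for one group.

    log_weights: log importance weights (assumed: target log prob minus
                 proposal log prob, one per trace).
    correct:     0/1 correctness labels, one per trace.
    """
    log_z = logsumexp(log_weights)
    norm = [math.exp(lw - log_z) for lw in log_weights]  # sums to 1
    accuracy = sum(w * c for w, c in zip(norm, correct))
    ess = 1.0 / sum(w * w for w in norm)  # equals (sum w)^2 / sum w^2
    return accuracy, ess


# Equal weights reduce to the plain mean and an ESS equal to the trace count.
acc_uniform, ess_uniform = is_accuracy_and_ess([0.0, 0.0, 0.0, 0.0], [1, 0, 1, 1])

# Large log weights do not overflow because the max is subtracted first.
acc_large, ess_large = is_accuracy_and_ess([1000.0, 1000.0], [1, 0])

# One dominant weight drives ESS toward 1.
acc_peaked, ess_peaked = is_accuracy_and_ess([0.0, -1000.0], [1, 0])

per_problem = grouped_logsumexp([0.0, 0.0, 5.0], ["p0", "p0", "p1"])
```

Working in log space matters because sequence-level log probabilities are often large in magnitude, so exponentiating them directly can overflow or underflow before normalization.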
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 8e6657e commit 9a70981
1 file changed, +255 -154 lines changed