Skip to content

Commit 1ea7a8d

Browse files
authored
fix: handle skipped eval intervals when ckpt_step jumps (#1903)
* fix: handle skipped eval intervals when ckpt_step jumps + fix eval step reporting When strict_async_level=False, ckpt_step can jump over eval interval boundaries (e.g. 24->26 skipping 25), permanently missing scheduled evals. Replace exact modulo check with range check that detects crossed intervals. Also fix eval metrics being logged at progress.step instead of ckpt_step, which caused eval results to appear at the wrong step in the metrics API. * format ruff * revert step=ckpt_step back to step=progress.step for eval logging W&B requires monotonically increasing steps. Using ckpt_step (which lags behind progress.step) causes W&B to silently drop eval metrics since training already logged at a higher step. The ckpt_step is already embedded in the metrics as progress/ckpt_step for platform-side lookup. * fix stale comment about eval subprocess * add unit tests for eval scheduling range check * extract eval scheduling logic into compute_eval_ckpt_step function Move range check logic from inline in orchestrator.py to a proper function in eval_utils.py so tests import the real implementation. * use plain functions instead of class in tests
1 parent f655f71 commit 1ea7a8d

File tree

3 files changed

+140
-10
lines changed

3 files changed

+140
-10
lines changed

src/prime_rl/orchestrator/eval_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,32 @@
1313
from prime_rl.utils.utils import capitalize
1414

1515

16+
def compute_eval_ckpt_step(
17+
ckpt_step: int,
18+
prev_ckpt_step: int,
19+
last_eval_step: int,
20+
interval: int,
21+
eval_base_model: bool = True,
22+
) -> int | None:
23+
"""Determine which checkpoint step (if any) should trigger an eval.
24+
25+
Handles the case where ckpt_step jumps over interval boundaries by finding
26+
the highest interval-aligned step in (prev_ckpt_step, ckpt_step].
27+
28+
Returns the interval step to eval at, or None if no eval should run.
29+
"""
30+
if ckpt_step <= prev_ckpt_step:
31+
return None
32+
highest_interval_step = (ckpt_step // interval) * interval
33+
if highest_interval_step > prev_ckpt_step and highest_interval_step > last_eval_step:
34+
if highest_interval_step == 0:
35+
if ckpt_step == 0 and eval_base_model:
36+
return 0
37+
else:
38+
return highest_interval_step
39+
return None
40+
41+
1642
def get_eval_sampling_args(sampling_config: EvalSamplingConfig) -> dict[str, Any]:
1743
"""Get sampling args for evaluation."""
1844
# Initialize sampling args

src/prime_rl/orchestrator/orchestrator.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import tomli_w
88

99
from prime_rl.orchestrator.advantage import compute_advantages
10-
from prime_rl.orchestrator.eval_utils import get_eval_sampling_args
10+
from prime_rl.orchestrator.eval_utils import compute_eval_ckpt_step, get_eval_sampling_args
1111
from prime_rl.orchestrator.event_loop_lag import EventLoopLagMonitor
1212
from prime_rl.orchestrator.patches import monkey_patch_chat_completion_logprobs, monkey_patch_oai_iterable_types
1313
from prime_rl.orchestrator.trajectories import build_vlm_image_cache, interleave_rollout
@@ -309,6 +309,8 @@ async def orchestrate(config: OrchestratorConfig):
309309

310310
# Track last online eval checkpoint step for this process
311311
last_eval_step = -1
312+
# Track previous ckpt_step to detect when ckpt_step jumps over eval interval boundaries
313+
prev_ckpt_step = -1
312314

313315
# Reset weights to base model if starting from scratch
314316
progress = Progress()
@@ -318,8 +320,12 @@ async def orchestrate(config: OrchestratorConfig):
318320
logger.info(f"Resuming training from checkpoint step {checkpoint_step}")
319321
scheduler.ckpt_step = progress.step # Always resume from the latest checkpoint
320322
if config.eval and config.eval.skip_eval_on_resume:
323+
prev_ckpt_step = scheduler.ckpt_step
321324
last_eval_step = scheduler.ckpt_step
322325
logger.info(f"Skipping online eval on resume (ckpt_step={scheduler.ckpt_step})")
326+
else:
327+
# Allow eval at resumed step by setting prev_ckpt_step one behind
328+
prev_ckpt_step = scheduler.ckpt_step - 1
323329

324330
# In NCCL mode, skip existence check - weights are broadcasted, not stored on disk
325331
check_exists = config.weight_broadcast.type != "nccl"
@@ -376,16 +382,25 @@ async def orchestrate(config: OrchestratorConfig):
376382
logger.info(f"Starting orchestrator step {progress.step}")
377383
step_start_time = time.perf_counter()
378384

379-
# Run evals BEFORE training (blocking, in subprocess to isolate event loop)
380-
# This ensures weights don't change during eval and eval doesn't cause event loop lag
381-
if (
382-
config.eval
383-
and ckpt_step % config.eval.interval == 0
384-
and ckpt_step > last_eval_step
385-
and ((ckpt_step == 0 and config.eval.eval_base_model) or ckpt_step > 0)
386-
):
385+
# Run evals BEFORE training (blocking). Weight updates are paused via
386+
# scheduler.checkpoint_ready during eval to ensure consistent weights.
387+
# Use range check to handle ckpt_step jumping over interval boundaries.
388+
eval_ckpt_step = None
389+
if config.eval:
390+
eval_ckpt_step = compute_eval_ckpt_step(
391+
ckpt_step=ckpt_step,
392+
prev_ckpt_step=prev_ckpt_step,
393+
last_eval_step=last_eval_step,
394+
interval=config.eval.interval,
395+
eval_base_model=config.eval.eval_base_model,
396+
)
397+
398+
if eval_ckpt_step is not None:
387399
last_eval_step = ckpt_step
388-
logger.info(f"Running evals for checkpoint step {ckpt_step}")
400+
if eval_ckpt_step != ckpt_step:
401+
logger.info(f"Running evals for interval step {eval_ckpt_step} (current ckpt_step={ckpt_step})")
402+
else:
403+
logger.info(f"Running evals for checkpoint step {ckpt_step}")
389404

390405
# Pause weight updates and re-scheduling of training rollouts during eval
391406
# to avoid evaluating across different checkpoints and avoid congestion
@@ -417,6 +432,9 @@ async def orchestrate(config: OrchestratorConfig):
417432
# Resume weight updates
418433
scheduler.checkpoint_ready.set()
419434

435+
# Update prev_ckpt_step for next iteration
436+
prev_ckpt_step = ckpt_step
437+
420438
# Schedule generating the training batch
421439
temperature = compute_temperature(progress.step, config.sampling, config.max_steps)
422440
sampling_args = get_sampling_args(config.sampling, temperature=temperature)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Unit tests for eval scheduling logic - specifically the range check
2+
that detects when ckpt_step jumps over eval interval boundaries."""
3+
4+
from prime_rl.orchestrator.eval_utils import compute_eval_ckpt_step
5+
6+
7+
def test_exact_hit():
    """An eval fires when ckpt_step lands exactly on an interval boundary."""
    assert compute_eval_ckpt_step(ckpt_step=25, prev_ckpt_step=24, last_eval_step=0, interval=25) == 25
10+
11+
12+
def test_jump_over_interval():
    """Jumping 24 -> 26 over the boundary still triggers the eval at 25."""
    got = compute_eval_ckpt_step(ckpt_step=26, prev_ckpt_step=24, last_eval_step=0, interval=25)
    assert got == 25
15+
16+
17+
def test_no_interval_crossed():
    """No eval is scheduled when no interval boundary lies in (prev, current]."""
    assert compute_eval_ckpt_step(ckpt_step=23, prev_ckpt_step=22, last_eval_step=0, interval=25) is None
20+
21+
22+
def test_base_model_eval_at_step_0():
    """Step 0 (the base model) is evaluated when eval_base_model is on."""
    got = compute_eval_ckpt_step(
        ckpt_step=0, prev_ckpt_step=-1, last_eval_step=-1, interval=25, eval_base_model=True
    )
    assert got == 0
27+
28+
29+
def test_base_model_eval_disabled():
    """Step 0 is skipped entirely when eval_base_model is off."""
    got = compute_eval_ckpt_step(
        ckpt_step=0, prev_ckpt_step=-1, last_eval_step=-1, interval=25, eval_base_model=False
    )
    assert got is None
34+
35+
36+
def test_no_double_eval():
    """A boundary that was already evaluated is never evaluated twice."""
    assert compute_eval_ckpt_step(ckpt_step=25, prev_ckpt_step=24, last_eval_step=25, interval=25) is None
39+
40+
41+
def test_no_change_in_ckpt_step():
    """No eval when ckpt_step has not advanced since the previous call."""
    assert compute_eval_ckpt_step(ckpt_step=25, prev_ckpt_step=25, last_eval_step=0, interval=25) is None
44+
45+
46+
def test_multiple_intervals_crossed():
    """When several boundaries are crossed at once, only the highest is returned."""
    got = compute_eval_ckpt_step(ckpt_step=76, prev_ckpt_step=24, last_eval_step=0, interval=25)
    assert got == 75
49+
50+
51+
def test_second_interval():
    """An exact hit on the second boundary schedules an eval there."""
    assert compute_eval_ckpt_step(ckpt_step=50, prev_ckpt_step=49, last_eval_step=25, interval=25) == 50
54+
55+
56+
def test_jump_across_second_interval():
    """Jumping 48 -> 51 over the second boundary still evals at 50."""
    got = compute_eval_ckpt_step(ckpt_step=51, prev_ckpt_step=48, last_eval_step=25, interval=25)
    assert got == 50
59+
60+
61+
def test_production_scenario_step25_skipped():
    """Regression case observed in run c14miuyha2yhxkw1z3eqgyub: 24 -> 26 jump."""
    got = compute_eval_ckpt_step(ckpt_step=26, prev_ckpt_step=24, last_eval_step=0, interval=25)
    assert got == 25
65+
66+
67+
def test_production_scenario_step50_exact():
    """After a jumped eval at 26, the exact hit at 50 still fires."""
    assert compute_eval_ckpt_step(ckpt_step=50, prev_ckpt_step=49, last_eval_step=26, interval=25) == 50
70+
71+
72+
def test_simulate_full_run():
    """Replay a realistic ckpt_step trace and collect every triggered eval step."""
    interval = 25
    last_eval_step = -1
    prev_ckpt_step = -1
    triggered = []

    for step in [0, 0, 3, 5, 10, 15, 20, 24, 26, 30, 35, 40, 48, 51, 60, 70, 74, 76]:
        hit = compute_eval_ckpt_step(step, prev_ckpt_step, last_eval_step, interval)
        if hit is not None:
            triggered.append(hit)
            # Mirror the orchestrator: last_eval_step records the ckpt step at
            # which the eval actually ran, not the interval boundary.
            last_eval_step = step
        prev_ckpt_step = step

    assert triggered == [0, 25, 50, 75]

0 commit comments

Comments
 (0)