Policy inference escalation in action-chunking: problem analysis and proposed Isaac Lab API #522
## Summary
Action-chunking policies (e.g., GR00T N1) output a chunk of 16–32 actions per inference call, replayed over subsequent env.step() calls. With chunk_length=16, inference should run once every 16 env steps. However, when environments terminate asynchronously and their chunk consumption rhythms desynchronize, the current any()-based trigger causes inference to fire on nearly every env step — each costing ~200 ms for GR1 or ~700 ms for DROID. With enough parallel environments, this effectively makes every env.step() pay the full policy inference cost, negating the benefit of action chunking entirely.
This document describes the problem, our current workaround, its limitations due to Isaac Lab's all-envs-at-once step() design, and a proposed API change to Isaac Lab that would enable a clean solution.
## 1. The Problem: Inference Frequency Degradation
### 1.1 Root Cause
In `ActionChunkingState.get_action()` (action_chunking.py):

```python
if self.env_requires_new_chunk.any():                  # ANY env needs a chunk → full-batch inference
    new_chunk = fetch_chunk_fn()                       # runs the model on ALL envs
    self.current_action_chunk[mask] = new_chunk[mask]  # only updates requesting envs
```

When an environment terminates and resets, it sets `env_requires_new_chunk = True`. Because `any()` triggers inference for the entire batch whenever a single env needs a new chunk, one reset is enough to desynchronize all environments.
### 1.2 Degradation Flow
```
Initial state: all envs synchronized
┌──────────────────────────────────────────────┐
│  env0   env1   env2   env3   env4   env5 ... │
│  idx=0  idx=0  idx=0  idx=0  idx=0  idx=0    │
│  Inference every 16 steps (~14 step/s)       │
└──────────────────────────────────────────────┘
          ↓ env2 terminates at step 5
┌──────────────────────────────────────────────┐
│  env0   env1   env2   env3   env4   env5 ... │
│  idx=5  idx=5  RESET  idx=5  idx=5  idx=5    │
│  env2 needs new chunk → any() = True         │
│  → full-batch inference for ALL envs         │
└──────────────────────────────────────────────┘
          ↓ more envs terminate at different times
┌──────────────────────────────────────────────┐
│ Fully desynchronized: chunk indices diverge  │
│ Almost every step, at least one env needs    │
│ a new chunk → inference on nearly every step │
│ With enough envs, every env.step() pays the  │
│ full ~200ms inference cost                   │
│ NEVER recovers                               │
└──────────────────────────────────────────────┘
```
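The degradation can be reproduced with a toy simulation (plain Python, not Isaac Lab code; `count_inference_calls`, the per-step termination probability, and all constants are invented for illustration): once asynchronous terminations desynchronize the chunk indices, an `any()`-style trigger fires on almost every step, while the synchronized case fires only once per chunk.

```python
import random

def count_inference_calls(num_envs=64, steps=1000, chunk_length=16,
                          term_prob=0.01, seed=0):
    """Toy model: each env consumes a chunk of `chunk_length` actions; a
    terminated env requests a fresh chunk immediately, and the any()-style
    trigger then runs inference for the whole batch."""
    rng = random.Random(seed)
    idx = [0] * num_envs  # per-env position inside its current chunk
    calls = 0
    for _ in range(steps):
        # any env at index 0 needs a new chunk -> full-batch inference
        if any(i == 0 for i in idx):
            calls += 1
        for e in range(num_envs):
            idx[e] += 1
            # chunk exhausted, or an asynchronous termination resets the env
            if idx[e] >= chunk_length or rng.random() < term_prob:
                idx[e] = 0
    return calls

print(count_inference_calls(term_prob=0))  # synchronized: 63 calls (once per chunk)
print(count_inference_calls())             # desynchronized: roughly one call per step
```

With no terminations, 1000 steps at `chunk_length=16` cost 63 inference calls; with a 1% per-step termination chance across 64 envs, nearly every step triggers one.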
## 2. Our Workaround: Fixed-Frequency Global Clock
We replace the per-env `any()` trigger with `global_step % chunk_length == 0`. Environments that terminate mid-chunk enter a waiting state and receive zero actions until the next inference tick.
This works, but creates a new problem: what happens to waiting environments?
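As a sketch, the workaround's trigger and action selection look roughly like this (plain-Python toy with illustrative names, not our actual implementation):

```python
def needs_inference(global_step: int, chunk_length: int) -> bool:
    """Fixed-frequency global-clock trigger: inference runs only on the
    shared tick, regardless of which envs have terminated."""
    return global_step % chunk_length == 0

def select_actions(chunk, chunk_idx, waiting):
    """Pick each env's current action from its chunk; envs parked in the
    waiting state receive zero actions until the next tick."""
    return [
        0.0 if waiting[e] else chunk[e][chunk_idx[e]]
        for e in range(len(chunk))
    ]

chunk = [[1.0, 2.0], [3.0, 4.0]]  # 2 envs, chunk_length = 2
print(select_actions(chunk, [0, 1], [False, True]))  # [1.0, 0.0]
```

Because the clock is global, all envs re-synchronize at every tick, so the per-step inference rate stays fixed at `1 / chunk_length` no matter when terminations occur.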
## 3. Why Isaac Lab Makes This Hard
### 3.1 The Waiting Period Problem
When an env terminates at step 5 of a 16-step chunk, it waits up to 15 steps for the next inference tick. During this waiting period, we'd ideally freeze the env entirely — no physics, no rendering, no recording. But Isaac Lab's `ManagerBasedRLEnv.step()` does not support per-env masking.
In `manager_based_rl_env.py:153-241`:

```python
def step(self, action):
    # ALL operations apply to ALL envs — no per-env skip possible
    self.action_manager.process_action(action)                       # ALL envs (line 173)
    self.recorder_manager.record_pre_step()                          # ALL envs (line 175)
    for _ in range(self.cfg.decimation):
        self.action_manager.apply_action()                           # ALL envs (line 185)
        self.scene.write_data_to_sim()                               # ALL envs (line 187)
        self.sim.step()                                              # ALL envs
        self.recorder_manager.record_post_physics_decimation_step()  # ALL envs (line 190)
    self.reset_buf = self.termination_manager.compute()              # ALL envs (line 204)
    self.recorder_manager.record_post_step()                         # ALL envs (line 213)
    if len(reset_env_ids) > 0:
        self.recorder_manager.record_pre_reset(env_ids)              # saves episode to HDF5 (line 219)
        self._reset_idx(env_ids)                                     # resets scene (line 221)
        self.recorder_manager.record_post_reset(env_ids)             # starts new episode (line 229)
```

Every operation — action application, physics stepping, rendering, recording, termination checking — applies to all environments simultaneously. There is no way to:
- Skip `apply_action()` for specific envs — a zero action still applies PD targets, causing minor drift
- Skip `record_pre_step()` / `record_post_step()` — the recorder accumulates junk data for waiting envs; if a waiting env drifts enough to re-terminate, a junk episode gets saved to HDF5
- Skip physics for specific envs — `sim.step()` advances all envs
- Suppress termination for specific envs — `termination_manager.compute()` checks all envs
### 3.2 Consequences
| Concern | Impact |
|---|---|
| Zero actions cause drift | PD controller targets ≈ default pose. Minor but nonzero drift over ~15 steps (~0.3s sim time). |
| Recorder pollution | record_pre_step() / record_post_step() record junk data for waiting envs. If re-termination occurs, junk episodes appear in HDF5. |
| Metrics contamination | Must filter junk episodes in the rollout loop; HDF5 data may still contain them. |
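The filtering burden looks roughly like this (a hypothetical sketch; `filter_junk_episodes`, the episode dicts, and `waiting_log` are invented for illustration): any episode that overlaps a waiting period must be dropped from the rollout metrics, while the same junk may still sit in the HDF5 file.

```python
def filter_junk_episodes(episodes, waiting_log):
    """Drop any episode whose env spent part of the episode in the waiting
    state, since its recorded transitions include zero-action junk steps."""
    kept = []
    for ep in episodes:
        steps = range(ep["start_step"], ep["end_step"])
        if any(waiting_log[ep["env_id"]][s] for s in steps):
            continue  # contaminated by waiting-period transitions
        kept.append(ep)
    return kept

episodes = [
    {"env_id": 0, "start_step": 0, "end_step": 3},  # never waited: clean
    {"env_id": 1, "start_step": 0, "end_step": 3},  # waited at steps 1-2: junk
]
waiting_log = {0: [False, False, False], 1: [False, True, True]}
print(len(filter_junk_episodes(episodes, waiting_log)))  # 1
```

This only cleans up the in-memory rollout loop; nothing in this post-hoc filter prevents the recorder from writing the contaminated episode to disk in the first place.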
## 4. Proposed Isaac Lab API: Per-Env Suspend/Resume
The cleanest solution requires Isaac Lab to support per-env masking in `step()`:

```python
class ManagerBasedRLEnv:
    def suspend_envs(self, env_ids: torch.Tensor):
        """Mark envs as suspended. step() skips physics/rendering/recording for them."""
        self._suspended_mask[env_ids] = True

    def resume_envs(self, env_ids: torch.Tensor):
        """Resume suspended envs."""
        self._suspended_mask[env_ids] = False

    def step(self, action):
        active_mask = ~self._suspended_mask
        # Only process active envs
        self.action_manager.process_action(action, active_mask)
        self.recorder_manager.record_pre_step(active_mask)
        for _ in range(self.cfg.decimation):
            self.action_manager.apply_action(active_mask)
            self.scene.write_data_to_sim(active_mask)
            self.sim.step()  # physics engine handles masking internally
            ...
        # Only check termination for active envs
        self.reset_buf[self._suspended_mask] = 0
        ...
```

This would let action-chunking policies:
- Suspend waiting envs immediately after termination
- Resume them at the next inference tick with a fresh action chunk
- Avoid zero-action drift, junk recorder data, and wasted physics/rendering computation
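A minimal sketch of how an action-chunking rollout would drive the proposed API (the `SuspendStub` class is a stand-in; `suspend_envs` / `resume_envs` do not exist in Isaac Lab today):

```python
class SuspendStub:
    """Stand-in for the proposed per-env suspend/resume bookkeeping,
    using a plain boolean list instead of a torch mask."""
    def __init__(self, num_envs: int):
        self.suspended = [False] * num_envs

    def suspend_envs(self, env_ids):
        for e in env_ids:
            self.suspended[e] = True

    def resume_envs(self, env_ids):
        for e in env_ids:
            self.suspended[e] = False

# Intended flow around the fixed-frequency inference clock:
env = SuspendStub(num_envs=4)
env.suspend_envs([2])   # env2 terminates mid-chunk -> park it immediately
print(env.suspended)    # [False, False, True, False]
env.resume_envs([2])    # next inference tick -> wake it with a fresh chunk
print(any(env.suspended))  # False
```

While an env is parked, `step()` would skip its physics, rendering, and recording entirely, so the waiting period costs nothing and produces no junk data.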
## Why This Matters Beyond Action Chunking
Per-env suspend/resume is useful for any scenario where different environments progress at different rates:
- Action chunking with asynchronous termination (this case)
- Variable-length episodes with expensive resets
- Multi-agent settings where agents operate on different timescales
- Curriculum learning where some envs need fewer steps