Skip to content

Commit b561f2f

Browse files
committed
init
1 parent ae50733 commit b561f2f

File tree

3 files changed

+231
-1
lines changed

3 files changed

+231
-1
lines changed

test/test_rb.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4152,6 +4152,124 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch):
41524152
assert_allclose_td(rb_test[:], rb[:])
41534153
assert rb.writer._cursor == rb_test._writer._cursor
41544154

4155+
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
def test_incremental_checkpointing(self, storage_type, tmpdir):
    """Test incremental checkpointing saves only changed data."""
    from torchrl.data.replay_buffers.checkpointers import TensorStorageCheckpointer

    torch.manual_seed(0)
    capacity = 100
    n_per_batch = 20

    def make_buffer():
        # Buffer with incremental checkpointing enabled on its storage.
        buf = ReplayBuffer(
            storage=storage_type(capacity),
            batch_size=n_per_batch,
        )
        buf.storage.checkpointer = TensorStorageCheckpointer(incremental=True)
        return buf

    def make_batch():
        return TensorDict(
            {
                "obs": torch.randn(n_per_batch, 4),
                "action": torch.randint(0, 2, (n_per_batch,)),
            },
            batch_size=[n_per_batch],
        )

    rb = make_buffer()
    # Second buffer used to verify that loads round-trip correctly.
    rb_test = make_buffer()

    ckpt_dir = Path(tmpdir) / "checkpoint"

    # First batch + checkpoint: a checkpoint cursor must be recorded.
    rb.extend(make_batch())
    rb.dumps(ckpt_dir)
    assert rb.storage._last_checkpoint_cursor is not None
    first_cursor = rb.storage._last_checkpoint_cursor

    # Load back and verify data and cursor restoration.
    rb_test.loads(ckpt_dir)
    assert_allclose_td(rb_test[:], rb[:])
    assert rb_test.storage._last_checkpoint_cursor == first_cursor

    # Second batch + checkpoint: incremental path, cursor must advance.
    rb.extend(make_batch())
    rb.dumps(ckpt_dir)
    assert rb.storage._last_checkpoint_cursor > first_cursor

    rb_test.loads(ckpt_dir)
    assert_allclose_td(rb_test[:], rb[:])

    # Keep writing until the buffer wraps around its capacity.
    for _ in range(5):
        rb.extend(make_batch())

    # Checkpoint after wrap-around (should fall back to a full save).
    rb.dumps(ckpt_dir)

    rb_test.loads(ckpt_dir)
    assert_allclose_td(rb_test[:], rb[:])
4236+
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
def test_incremental_checkpointing_no_changes(self, storage_type, tmpdir):
    """Test incremental checkpoint when no data changed."""
    from torchrl.data.replay_buffers.checkpointers import TensorStorageCheckpointer

    torch.manual_seed(0)
    capacity = 50
    n_samples = 10

    rb = ReplayBuffer(
        storage=storage_type(capacity),
        batch_size=n_samples,
    )
    rb.storage.checkpointer = TensorStorageCheckpointer(incremental=True)

    ckpt_dir = Path(tmpdir) / "checkpoint"

    # Write once, checkpoint once.
    batch = TensorDict(
        {"obs": torch.randn(n_samples, 4)},
        batch_size=[n_samples],
    )
    rb.extend(batch)
    rb.dumps(ckpt_dir)

    # Dump again without new data: incremental save should be a no-op.
    rb.dumps(ckpt_dir)

    # The checkpoint must still load back to identical contents.
    rb_test = ReplayBuffer(
        storage=storage_type(capacity),
        batch_size=n_samples,
    )
    rb_test.storage.checkpointer = TensorStorageCheckpointer(incremental=True)
    rb_test.loads(ckpt_dir)
    assert_allclose_td(rb_test[:], rb[:])
4272+
41554273

41564274
@pytest.mark.skipif(not _has_ray, reason="ray required for this test.")
41574275
class TestRayRB:

torchrl/data/replay_buffers/checkpointers.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,82 @@ class TensorStorageCheckpointer(StorageCheckpointerBase):
331331
This class will call save and load hooks if provided. These hooks should take as input the
332332
data being transformed as well as the path where the data should be saved.
333333
334+
This checkpointer supports incremental saves when checkpointing to the same path repeatedly.
335+
Only the data that changed since the last checkpoint is written, significantly reducing
336+
checkpoint time for large buffers. This is controlled by the ``incremental`` parameter.
337+
338+
Keyword Args:
339+
incremental (bool, optional): if ``True``, enables incremental checkpointing where only
340+
modified data is saved when checkpointing to a path that already contains a checkpoint.
341+
This can dramatically reduce checkpoint time for large buffers with frequent saves.
342+
Defaults to ``False`` for backward compatibility.
343+
334344
"""
335345

346+
def __init__(self, *, incremental: bool = False):
    """Initialize the checkpointer.

    Keyword Args:
        incremental (bool, optional): when ``True``, repeated dumps to the
            same path only rewrite the data modified since the previous
            checkpoint. Defaults to ``False``.
    """
    super().__init__()
    # Whether dumps() may patch an existing on-disk checkpoint in place.
    self.incremental = incremental
349+
350+
def _compute_dirty_range(
351+
self, last_checkpoint_cursor, current_cursor, max_size, is_full
352+
):
353+
"""Compute the range of indices that changed since the last checkpoint.
354+
355+
Args:
356+
last_checkpoint_cursor: Cursor position at the time of the last checkpoint.
357+
current_cursor: Current cursor position (where next write will go).
358+
max_size: Maximum size of the storage along dimension 0.
359+
is_full: Whether the storage is completely full.
360+
361+
Returns:
362+
A tuple (start, end) representing the dirty range, or None if a full save is needed.
363+
The range is [start, end) (end-exclusive).
364+
"""
365+
if last_checkpoint_cursor is None:
366+
# First checkpoint, need full save
367+
return None
368+
369+
if current_cursor == last_checkpoint_cursor:
370+
# Cursor hasn't moved. But if the buffer is full, it might have wrapped around
371+
# completely (wrote exactly max_size elements since last checkpoint).
372+
# In that case, we need a full save.
373+
# If buffer is not full and cursor hasn't moved, no changes.
374+
if is_full:
375+
# Could have wrapped around completely - do full save to be safe
376+
return None
377+
# No changes since last checkpoint
378+
return (0, 0)
379+
380+
if current_cursor > last_checkpoint_cursor:
381+
# Simple case: no wrap-around
382+
return (last_checkpoint_cursor, current_cursor)
383+
384+
# Wrap-around occurred: current_cursor < last_checkpoint_cursor
385+
# This means we wrote from last_checkpoint_cursor to max_size, then from 0 to current_cursor
386+
# For simplicity, we do a full save on wrap-around since it's complex to handle
387+
# and wrap-around typically means most of the buffer changed anyway
388+
return None
389+
390+
def _save_incremental(self, storage, _storage, path, dirty_range):
    """Write only the dirty slice into the memmap files already at ``path``.

    Args:
        storage: the storage object.
        _storage: the underlying tensor collection holding current data.
        path: path to the checkpoint directory with an existing memmap save.
        dirty_range: end-exclusive ``(start, end)`` tuple of indices to write.
    """
    start, end = dirty_range
    # Empty range: the on-disk checkpoint is already up to date.
    if start == end:
        return

    # Map the checkpoint that already exists on disk and overwrite the
    # modified slice in place.
    on_disk = TensorDict.load_memmap(path)
    on_disk[start:end].update_(_storage[start:end])
409+
336410
def dumps(self, storage, path):
337411
path = Path(path)
338412
path.mkdir(exist_ok=True)
@@ -344,6 +418,13 @@ def dumps(self, storage, path):
344418

345419
self._set_hooks_shift_is_full(storage)
346420

421+
# Compute current cursor position for checkpoint tracking
422+
last_cursor = storage._last_cursor
423+
if last_cursor is not None:
424+
current_checkpoint_cursor = self._get_shift_from_last_cursor(last_cursor)
425+
else:
426+
current_checkpoint_cursor = 0
427+
347428
for hook in self._save_hooks:
348429
_storage = hook(_storage, path=path)
349430
if is_tensor_collection(_storage):
@@ -353,8 +434,28 @@ def dumps(self, storage, path):
353434
and Path(_storage.saved_path).absolute() == Path(path).absolute()
354435
):
355436
_storage.memmap_refresh_()
437+
elif (
438+
self.incremental
439+
and (path / "storage_metadata.json").exists()
440+
and storage._last_checkpoint_cursor is not None
441+
):
442+
# Incremental save: only save what changed
443+
dirty_range = self._compute_dirty_range(
444+
storage._last_checkpoint_cursor,
445+
current_checkpoint_cursor,
446+
storage._max_size_along_dim0(),
447+
storage._is_full,
448+
)
449+
if dirty_range is not None:
450+
self._save_incremental(storage, _storage, path, dirty_range)
451+
else:
452+
# Wrap-around or other case requiring full save
453+
_storage.memmap(
454+
path,
455+
copy_existing=True,
456+
)
356457
else:
357-
# try to load the path and overwrite.
458+
# Full save (first checkpoint or incremental disabled)
358459
_storage.memmap(
359460
path,
360461
copy_existing=True, # num_threads=torch.get_num_threads()
@@ -364,12 +465,16 @@ def dumps(self, storage, path):
364465
_save_pytree(_storage, metadata, path)
365466
is_pytree = True
366467

468+
# Update the checkpoint cursor for next incremental save
469+
storage._last_checkpoint_cursor = current_checkpoint_cursor
470+
367471
with open(path / "storage_metadata.json", "w") as file:
368472
json.dump(
369473
{
370474
"metadata": metadata,
371475
"is_pytree": is_pytree,
372476
"len": storage._len,
477+
"last_checkpoint_cursor": current_checkpoint_cursor,
373478
},
374479
file,
375480
)
@@ -453,6 +558,11 @@ def loads(self, storage, path):
453558
storage._storage.copy_(_storage)
454559
storage._len = _len
455560

561+
# Restore checkpoint cursor for incremental saves
562+
last_checkpoint_cursor = metadata.get("last_checkpoint_cursor")
563+
if last_checkpoint_cursor is not None:
564+
storage._last_checkpoint_cursor = last_checkpoint_cursor
565+
456566

457567
class FlatStorageCheckpointer(TensorStorageCheckpointer):
458568
"""Saves the storage in a compact form, saving space on the TED format.
@@ -539,6 +649,7 @@ def __init__(
539649
**kwargs,
540650
):
541651
StorageCheckpointerBase.__init__(self)
652+
self.incremental = False # H5 does not support incremental saves
542653
ted2_kwargs = kwargs
543654
if done_keys is not None:
544655
ted2_kwargs["done_keys"] = done_keys

torchrl/data/replay_buffers/storages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ def __init__(
733733
)
734734
self._storage = storage
735735
self._last_cursor = None
736+
self._last_checkpoint_cursor = None
736737
self.__dict__["_storage_keys"] = None
737738

738739
@property

0 commit comments

Comments
 (0)