@@ -331,8 +331,82 @@ class TensorStorageCheckpointer(StorageCheckpointerBase):
331331 This class will call save and load hooks if provided. These hooks should take as input the
332332 data being transformed as well as the path where the data should be saved.
333333
334+ This checkpointer supports incremental saves when checkpointing to the same path repeatedly.
335+ Only the data that changed since the last checkpoint is written, significantly reducing
336+ checkpoint time for large buffers. This is controlled by the ``incremental`` parameter.
337+
338+ Keyword Args:
339+ incremental (bool, optional): if ``True``, enables incremental checkpointing where only
340+ modified data is saved when checkpointing to a path that already contains a checkpoint.
341+ This can dramatically reduce checkpoint time for large buffers with frequent saves.
342+ Defaults to ``False`` for backward compatibility.
343+
334344 """
335345
def __init__(self, *, incremental: bool = False) -> None:
    """Create the checkpointer and record the incremental-save preference.

    Keyword Args:
        incremental (bool, optional): stored as ``self.incremental``. When
            truthy, ``dumps`` may write only the rows that changed since
            the previous checkpoint at the same path instead of rewriting
            everything. Defaults to ``False``.
    """
    # The base initializer runs first, exactly as before; the flag is
    # attached afterwards.
    super().__init__()
    self.incremental = incremental
349+
350+ def _compute_dirty_range (
351+ self , last_checkpoint_cursor , current_cursor , max_size , is_full
352+ ):
353+ """Compute the range of indices that changed since the last checkpoint.
354+
355+ Args:
356+ last_checkpoint_cursor: Cursor position at the time of the last checkpoint.
357+ current_cursor: Current cursor position (where next write will go).
358+ max_size: Maximum size of the storage along dimension 0.
359+ is_full: Whether the storage is completely full.
360+
361+ Returns:
362+ A tuple (start, end) representing the dirty range, or None if a full save is needed.
363+ The range is [start, end) (end-exclusive).
364+ """
365+ if last_checkpoint_cursor is None :
366+ # First checkpoint, need full save
367+ return None
368+
369+ if current_cursor == last_checkpoint_cursor :
370+ # Cursor hasn't moved. But if the buffer is full, it might have wrapped around
371+ # completely (wrote exactly max_size elements since last checkpoint).
372+ # In that case, we need a full save.
373+ # If buffer is not full and cursor hasn't moved, no changes.
374+ if is_full :
375+ # Could have wrapped around completely - do full save to be safe
376+ return None
377+ # No changes since last checkpoint
378+ return (0 , 0 )
379+
380+ if current_cursor > last_checkpoint_cursor :
381+ # Simple case: no wrap-around
382+ return (last_checkpoint_cursor , current_cursor )
383+
384+ # Wrap-around occurred: current_cursor < last_checkpoint_cursor
385+ # This means we wrote from last_checkpoint_cursor to max_size, then from 0 to current_cursor
386+ # For simplicity, we do a full save on wrap-around since it's complex to handle
387+ # and wrap-around typically means most of the buffer changed anyway
388+ return None
389+
390+ def _save_incremental (self , storage , _storage , path , dirty_range ):
391+ """Save only the dirty range to existing memmap files.
392+
393+ Args:
394+ storage: The storage object.
395+ _storage: The underlying tensor collection.
396+ path: Path to the checkpoint directory.
397+ dirty_range: Tuple (start, end) of indices to save.
398+ """
399+ start , end = dirty_range
400+ if start == end :
401+ # No changes to save
402+ return
403+
404+ # Load existing memmap at path
405+ existing = TensorDict .load_memmap (path )
406+
407+ # Update only the dirty indices
408+ existing [start :end ].update_ (_storage [start :end ])
409+
336410 def dumps (self , storage , path ):
337411 path = Path (path )
338412 path .mkdir (exist_ok = True )
@@ -344,6 +418,13 @@ def dumps(self, storage, path):
344418
345419 self ._set_hooks_shift_is_full (storage )
346420
421+ # Compute current cursor position for checkpoint tracking
422+ last_cursor = storage ._last_cursor
423+ if last_cursor is not None :
424+ current_checkpoint_cursor = self ._get_shift_from_last_cursor (last_cursor )
425+ else :
426+ current_checkpoint_cursor = 0
427+
347428 for hook in self ._save_hooks :
348429 _storage = hook (_storage , path = path )
349430 if is_tensor_collection (_storage ):
@@ -353,8 +434,28 @@ def dumps(self, storage, path):
353434 and Path (_storage .saved_path ).absolute () == Path (path ).absolute ()
354435 ):
355436 _storage .memmap_refresh_ ()
437+ elif (
438+ self .incremental
439+ and (path / "storage_metadata.json" ).exists ()
440+ and storage ._last_checkpoint_cursor is not None
441+ ):
442+ # Incremental save: only save what changed
443+ dirty_range = self ._compute_dirty_range (
444+ storage ._last_checkpoint_cursor ,
445+ current_checkpoint_cursor ,
446+ storage ._max_size_along_dim0 (),
447+ storage ._is_full ,
448+ )
449+ if dirty_range is not None :
450+ self ._save_incremental (storage , _storage , path , dirty_range )
451+ else :
452+ # Wrap-around or other case requiring full save
453+ _storage .memmap (
454+ path ,
455+ copy_existing = True ,
456+ )
356457 else :
357- # try to load the path and overwrite.
458+ # Full save (first checkpoint or incremental disabled)
358459 _storage .memmap (
359460 path ,
360461 copy_existing = True , # num_threads=torch.get_num_threads()
@@ -364,12 +465,16 @@ def dumps(self, storage, path):
364465 _save_pytree (_storage , metadata , path )
365466 is_pytree = True
366467
468+ # Update the checkpoint cursor for next incremental save
469+ storage ._last_checkpoint_cursor = current_checkpoint_cursor
470+
367471 with open (path / "storage_metadata.json" , "w" ) as file :
368472 json .dump (
369473 {
370474 "metadata" : metadata ,
371475 "is_pytree" : is_pytree ,
372476 "len" : storage ._len ,
477+ "last_checkpoint_cursor" : current_checkpoint_cursor ,
373478 },
374479 file ,
375480 )
@@ -453,6 +558,11 @@ def loads(self, storage, path):
453558 storage ._storage .copy_ (_storage )
454559 storage ._len = _len
455560
561+ # Restore checkpoint cursor for incremental saves
562+ last_checkpoint_cursor = metadata .get ("last_checkpoint_cursor" )
563+ if last_checkpoint_cursor is not None :
564+ storage ._last_checkpoint_cursor = last_checkpoint_cursor
565+
456566
457567class FlatStorageCheckpointer (TensorStorageCheckpointer ):
458568 """Saves the storage in a compact form, saving space on the TED format.
@@ -539,6 +649,7 @@ def __init__(
539649 ** kwargs ,
540650 ):
541651 StorageCheckpointerBase .__init__ (self )
652+ self .incremental = False # H5 does not support incremental saves
542653 ted2_kwargs = kwargs
543654 if done_keys is not None :
544655 ted2_kwargs ["done_keys" ] = done_keys
0 commit comments