Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/lerobot/datasets/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import glob
import importlib
import logging
import os
import queue
import shutil
import tempfile
Expand Down Expand Up @@ -260,11 +261,21 @@ def decode_video_frames_torchvision(


class VideoDecoderCache:
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
"""Thread-safe, process-local cache for video decoders to avoid expensive re-initialization."""

def __init__(self):
self._cache: dict[str, tuple[Any, Any]] = {}
self._lock = Lock()
self._owner_pid = os.getpid()

def _reset_if_forked(self) -> None:
# Drop entries inherited from a parent process. Don't close inherited
# file handles from a child — that can corrupt the parent's view; let GC
# reclaim them. Caller must hold the lock.
pid = os.getpid()
if pid != self._owner_pid:
self._cache = {}
self._owner_pid = pid

def get_decoder(self, video_path: str):
"""Get a cached decoder or create a new one."""
Expand All @@ -276,6 +287,7 @@ def get_decoder(self, video_path: str):
video_path = str(video_path)

with self._lock:
self._reset_if_forked()
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
decoder = VideoDecoder(file_handle, seek_mode="approximate")
Expand All @@ -286,16 +298,23 @@ def get_decoder(self, video_path: str):
def clear(self):
    """Empty the cache, closing file handles that this process owns.

    File handles are closed only when the calling process is the one
    recorded as the cache owner. In any other process (for example a
    forked worker) the inherited entries are dropped without being
    closed, so the parent's handles are left alone.

    After the call the cache is empty and owned by the calling process,
    even if closing one of the handles raised; such an exception still
    propagates to the caller.
    """
    with self._lock:
        current_pid = os.getpid()
        try:
            # Close each handle exactly once, and only in the owning process.
            if current_pid == self._owner_pid:
                for _, file_handle in self._cache.values():
                    file_handle.close()
        finally:
            # Always leave a consistent state: no stale (possibly closed)
            # entries, and ownership transferred to the current process.
            self._cache.clear()
            self._owner_pid = current_pid

def size(self) -> int:
    """Report how many entries the cache currently holds.

    Returns:
        The number of cached entries, read while holding the lock.
    """
    with self._lock:
        entry_count = len(self._cache)
    return entry_count


def lerobot_worker_init_fn(worker_id: int) -> None:
    """Reset the module-level torchcodec cache; use as ``DataLoader(worker_init_fn=...)``.

    Args:
        worker_id: Accepted to match the ``worker_init_fn`` call signature;
            not used by this function.
    """
    _default_decoder_cache.clear()


class FrameTimestampError(ValueError):
"""Helper error to indicate the retrieved timestamps exceed the queried ones"""

Expand Down