Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,8 @@ async def orchestrate(config: OrchestratorConfig):
if is_vlm:
vlm_cache = build_vlm_image_cache(train_rollouts, processor)
logger.info(
f"VLM timing: extract={vlm_cache.extract_time:.2f}s, preprocess={vlm_cache.preprocess_time:.2f}s"
f"VLM timing: extract={vlm_cache.extract_time:.2f}s, preprocess={vlm_cache.preprocess_time:.2f}s "
f"({vlm_cache.num_unique_images} unique images from {vlm_cache.num_unique_examples} examples)"
)
else:
vlm_cache = None
Expand Down
26 changes: 19 additions & 7 deletions src/prime_rl/orchestrator/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,22 @@ def _preprocess_images_batched(
logger = get_logger()
image_sizes = [(img.width, img.height) for img in images]

# Process images in chunks to avoid OOM
all_pixel_values_list = []
all_grid_thw_list = []
for i in range(0, len(images), chunk_size):
chunk = images[i : i + chunk_size]
# Process images in chunks to avoid OOM, parallelized across threads
# (PIL/numpy release the GIL so threads give real concurrency here)
chunks = [images[i : i + chunk_size] for i in range(0, len(images), chunk_size)]

def _process_chunk(chunk: list[Image.Image]) -> tuple[torch.Tensor, torch.Tensor]:
processed = processor.image_processor(images=chunk, return_tensors="pt")
all_pixel_values_list.append(processed["pixel_values"])
all_grid_thw_list.append(processed["image_grid_thw"])
return processed["pixel_values"], processed["image_grid_thw"]

if len(chunks) > 1:
with ThreadPoolExecutor(max_workers=min(len(chunks), 8)) as pool:
results = list(pool.map(_process_chunk, chunks))
else:
results = [_process_chunk(chunks[0])]

all_pixel_values_list = [r[0] for r in results]
all_grid_thw_list = [r[1] for r in results]

all_pixel_values = torch.cat(all_pixel_values_list, dim=0)
all_grid_thw = torch.cat(all_grid_thw_list, dim=0)
Expand Down Expand Up @@ -407,6 +415,7 @@ def __init__(
self._step_indices: dict[int, list[list[int]]] | None = None
self.cache = cache
self.num_unique_examples = num_unique_examples
self.num_unique_images = 0
self.extract_time = extract_time
self.preprocess_time = preprocess_time

Expand All @@ -416,6 +425,7 @@ def from_store(
store: _ImageStore | None,
step_indices: dict[int, list[list[int]]],
num_unique_examples: int,
num_unique_images: int,
extract_time: float,
preprocess_time: float,
) -> "VLMImageCache":
Expand All @@ -425,6 +435,7 @@ def from_store(
obj._step_indices = step_indices
obj.cache = {}
obj.num_unique_examples = num_unique_examples
obj.num_unique_images = num_unique_images
obj.extract_time = extract_time
obj.preprocess_time = preprocess_time
return obj
Expand Down Expand Up @@ -486,6 +497,7 @@ def build_vlm_image_cache(rollouts: list[vf.RolloutOutput], processor) -> VLMIma
store=store,
step_indices=step_indices,
num_unique_examples=len(unique_example_ids),
num_unique_images=len(all_images),
extract_time=extract_time,
preprocess_time=preprocess_time,
)
2 changes: 2 additions & 0 deletions tests/unit/orchestrator/test_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,7 @@ def test_vlm_image_cache_from_store():
store=store,
step_indices=step_indices,
num_unique_examples=1,
num_unique_images=2,
extract_time=0.0,
preprocess_time=0.0,
)
Expand Down Expand Up @@ -2100,6 +2101,7 @@ def test_vlm_image_cache_from_store_no_images():
store=None,
step_indices=step_indices,
num_unique_examples=1,
num_unique_images=0,
extract_time=0.0,
preprocess_time=0.0,
)
Expand Down