Description & Motivation
Currently, PyTorch Lightning assumes that Dataset and DataLoader parameters (e.g., sequence length, unrolling steps) remain fixed for the duration of a fit() call (unless a static policy is set via reload_dataloaders_every_n_epochs in the Trainer). When a user needs to adjust these parameters dynamically based on real-time metrics (for example, in curriculum learning or adaptive unrolling), they are forced into one of two suboptimal patterns:
- Manual Script Restarts: Stopping execution, saving a checkpoint, updating the configuration, and resuming. This introduces significant overhead and breaks the continuity of the training loop.
- Dataset Over-engineering: Building complex, stateful Dataset objects that internally track metrics and handle the adaptation logic themselves.
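To make the second anti-pattern concrete, here is a minimal, framework-free sketch of such an over-engineered dataset. The class name, attributes, and threshold logic are all hypothetical and only illustrate the pattern; this is not Lightning code.

```python
class AdaptiveUnrollDataset:
    """Hypothetical example of the 'over-engineered' pattern: the dataset
    itself tracks a training metric and silently changes its own
    unrolling behaviour."""

    def __init__(self, sequence, unroll_steps=1, loss_threshold=0.5):
        self.sequence = sequence
        self.unroll_steps = unroll_steps
        self.loss_threshold = loss_threshold

    def update_metric(self, loss):
        # Training code must remember to push metrics into the dataset by hand.
        if loss < self.loss_threshold:
            self.unroll_steps += 1

    def __len__(self):
        return max(1, len(self.sequence) - self.unroll_steps)

    def __getitem__(self, idx):
        # The window length depends on mutable internal state, which makes
        # shuffling, caching, and multiprocessing workers hard to reason about.
        return self.sequence[idx : idx + self.unroll_steps + 1]
```

Because the sampling behaviour mutates mid-epoch, such datasets interact badly with DataLoader workers and persistent iterators, which is exactly why a first-class reload mechanism would be preferable.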
Pitch
I propose introducing a formal mechanism to trigger in-place DataLoader reconfiguration without exiting the fit() loop. This would likely manifest as a new hook or an enhancement to the existing reload_dataloaders_every_n_epochs functionality, allowing for:
- Metric-driven triggers: Re-initializing the DataLoader when a specific monitored value hits a threshold.
- Parameter injection: A way to pass updated hyperparameters from the LightningModule or a Callback directly into the DataModule during the reload process.
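The two mechanisms above can be sketched without Lightning as a toy training loop that rebuilds its batches in place when the monitored metric crosses a threshold. All names here (LoaderConfig, make_batches, training_loop) are hypothetical stand-ins for the DataModule/reload machinery, not a proposed implementation.

```python
from dataclasses import dataclass


@dataclass
class LoaderConfig:
    """Stand-in for the mutable dataloader hyperparameters."""
    unroll_steps: int = 1


def make_batches(config, data):
    # Stands in for DataModule.train_dataloader(): batches are re-created
    # from the *current* config each time it is called.
    n = config.unroll_steps
    return [data[i : i + n] for i in range(0, len(data), n)]


def training_loop(data, losses, threshold=0.5):
    """Toy loop: rebuild batches whenever the monitored loss drops below
    the threshold, without ever leaving the loop."""
    config = LoaderConfig()
    batches = make_batches(config, data)
    history = []
    for loss in losses:
        history.append(len(batches))      # record current batch granularity
        if loss < threshold:              # metric-driven trigger
            config.unroll_steps += 1      # parameter injection
            batches = make_batches(config, data)  # in-place reload
    return config.unroll_steps, history
```

In Lightning terms, the trigger check would live in a hook or Callback, the config mutation would target the DataModule, and the reload call would be the new mechanism this issue requests.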
Alternatives
No response
Additional context
Here is a sketch of how a callback could be implemented if the requested feature is accepted:
```python
def on_train_batch_end(self, outputs, batch, batch_idx):
    # This is just an example; more complex trigger logic would also work.
    if self.trainer.callback_metrics["train_loss"] < 0.5:
        # Trigger a reload with an updated unrolling length.
        self.trainer.datamodule.unroll_steps += 1
        self.trainer.manual_dataloader_reload()  # proposed new API
```

I already have draft code for this improvement, and it would be a pleasure to open a PR if the feature is of interest.
cc @lantiga