Open
Labels: bug, checkpointing, lightningdatamodule, ver: 2.5.x
Description
Bug description
When resuming from a checkpoint, `setup_data()` is called during checkpoint restoration, which updates `self._last_train_dl_reload_epoch`. As a result, the train dataloader is not reloaded at the expected epoch later in training, which breaks training reproducibility.
I trained a model for 50 epochs and saved a checkpoint.
The train dataloader is configured to reload every 50 epochs (`reload_dataloaders_every_n_epochs=50`).
When resuming, `setup_data()` is called and the dataloader is reloaded with the reload epoch recorded as 49, whereas in an uninterrupted run the next reload would happen at the start of the 51st epoch (epoch 50, zero-indexed). So when epoch 50 starts, the dataloader is not reloaded, because only one epoch has passed since the recorded reload.
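To make the off-by-one concrete, here is a simplified, pure-Python model of the reload schedule. It assumes that the fit loop reloads the train dataloader at the start of an epoch when `current_epoch - _last_train_dl_reload_epoch >= reload_dataloaders_every_n_epochs`, and that `_last_train_dl_reload_epoch` starts at negative infinity in a fresh run. The helper names `should_reload` and `reload_epochs`, and the 150-epoch horizon, are hypothetical and not part of Lightning's API.

```python
# Simplified model of the train-dataloader reload check (an assumption about
# the fit loop's logic, not Lightning's actual code).
def should_reload(current_epoch: int, last_reload_epoch: float, n_epochs: int) -> bool:
    return n_epochs > 0 and current_epoch - last_reload_epoch >= n_epochs


def reload_epochs(start_epoch: int, max_epochs: int,
                  last_reload_epoch: float, n_epochs: int) -> list[int]:
    """Return the zero-indexed epochs at whose start the dataloader is rebuilt."""
    reloads = []
    for epoch in range(start_epoch, max_epochs):
        if should_reload(epoch, last_reload_epoch, n_epochs):
            reloads.append(epoch)
            last_reload_epoch = epoch  # record the reload, as setup_data() does
    return reloads


if __name__ == "__main__":
    # Fresh run: nothing has been loaded yet.
    print(reload_epochs(0, 150, float("-inf"), 50))  # [0, 50, 100]
    # Resumed at epoch 50 with the last reload recorded as epoch 49.
    print(reload_epochs(50, 150, 49, 50))            # [99, 149]
    # Resumed at epoch 50 after resetting the recorded value to 0.
    print(reload_epochs(50, 150, 0, 50))             # [50, 100]
```

In this model, a resumed run whose last reload is recorded as epoch 49 skips the reload at epoch 50 and shifts every later reload by 49 epochs, while resetting the recorded value to 0 (as the workaround in the reproduction script does) restores the original schedule.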
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
import torch
from lightning import Trainer, LightningModule
from torch.utils.data import Dataset, DataLoader

CHECKPOINT_PATH = "stage1.ckpt"
DEBUG = False


class FakeDataset(Dataset):
    """First-stage data: 10 samples of zeros."""

    def __init__(self):
        self.data = [torch.zeros(3) for _ in range(10)]
        self.labels = [torch.zeros(1) for _ in range(10)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class FakeDataset2(FakeDataset):
    """Second-stage data: 20 samples of ones."""

    def __init__(self):
        super().__init__()
        self.data = [torch.ones(3) for _ in range(20)]
        self.labels = [torch.ones(1) for _ in range(20)]


class SimpleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
        self.first_stage_epochs: int = 2
        self.second_stage_epochs: int = 10

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Fails if the dataloader was not reloaded with FakeDataset2 in the second stage
        if self.current_epoch < self.first_stage_epochs:
            assert torch.all(x == 0), "Data in first stage should be zeros"
        else:
            assert torch.all(x == 1), "Data in second stage should be ones"
        y_hat = self.layer(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.01)

    def train_dataloader(self):
        # Choose the dataset based on the epoch at which the dataloader is (re)built
        if self.current_epoch < self.first_stage_epochs:
            dataset = FakeDataset()
        else:
            dataset = FakeDataset2()
        return DataLoader(dataset, batch_size=1)

    def on_train_epoch_end(self):
        # Save a checkpoint at the end of the first stage
        if self.current_epoch == self.first_stage_epochs - 1:
            print("Completed first stage")
            self.trainer.save_checkpoint(CHECKPOINT_PATH)

    def on_train_start(self):
        if self.current_epoch > 0 and DEBUG:
            print("Resumed training for second stage")
            # The fit loop calls setup_data() during checkpoint restoration,
            # which sets _last_train_dl_reload_epoch to current_epoch - 1.
            # Resetting it to 0 lets the reload check pass again at epoch 2.
            self.trainer.fit_loop._last_train_dl_reload_epoch = 0


if __name__ == "__main__":
    # Uncomment the following line to apply the workaround in on_train_start
    # DEBUG = True
    dataset = FakeDataset()
    model = SimpleModule()
    trainer = Trainer(
        max_epochs=model.first_stage_epochs + model.second_stage_epochs,
        accelerator="gpu",  # use "auto" or "cpu" if no GPU is available
        devices=1,
        log_every_n_steps=1,
        reload_dataloaders_every_n_epochs=model.first_stage_epochs,
    )

    #####################################################
    # Train the first time, both 1st stage and 2nd stage
    #####################################################
    trainer.fit(model)
    print("Successfully trained first stage and second stage")

    #####################################################
    # Resume training from checkpoint and only train second stage
    #####################################################
    trainer.fit(model, ckpt_path=CHECKPOINT_PATH)
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response