Skip to content

Commit fedc9ab

Browse files
authored
fix: use persistent=False for n_obs_per_sample buffer in TorchMRVAE to avoid state_dict conflict on load of custom dataloader module (#3756)
1 parent 40df93a commit fedc9ab

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/scvi/external/mrvi_torch/_module.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,10 @@ def __init__(
453453
self.qz_kwargs = qz_kwargs
454454
self.qu_kwargs = qu_kwargs
455455

456-
# Register n_obs_per_sample as a buffer so it moves to GPU with the model
456+
# Register n_obs_per_sample as a non-persistent buffer so it moves to GPU with the model
457+
# but is excluded from state_dict to avoid key conflicts when loading via setup_datamodule
457458
if n_obs_per_sample is not None:
458-
self.register_buffer("n_obs_per_sample", n_obs_per_sample)
459+
self.register_buffer("n_obs_per_sample", n_obs_per_sample, persistent=False)
459460
else:
460461
self.n_obs_per_sample = None
461462

0 commit comments

Comments
 (0)