We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 40df93a commit fedc9abCopy full SHA for fedc9ab
src/scvi/external/mrvi_torch/_module.py
@@ -453,9 +453,10 @@ def __init__(
453
self.qz_kwargs = qz_kwargs
454
self.qu_kwargs = qu_kwargs
455
456
- # Register n_obs_per_sample as a buffer so it moves to GPU with the model
+ # Register n_obs_per_sample as a non-persistent buffer so it moves to GPU with the model
457
+ # but is excluded from state_dict to avoid key conflicts when loading via setup_datamodule
458
if n_obs_per_sample is not None:
- self.register_buffer("n_obs_per_sample", n_obs_per_sample)
459
+ self.register_buffer("n_obs_per_sample", n_obs_per_sample, persistent=False)
460
else:
461
self.n_obs_per_sample = None
462
0 commit comments