[Possible bug]: RNG is reset for each parameter, causing identical perturbations for parameters with the same shape #1

@chenchenygu

Description

In the code for randomly perturbing model weights, a new torch.Generator is created and seeded with the same value inside the loop over parameters:

def perturb_self_weights(self, seed, noise_scale, negate=False):
    self._set_seed(seed)
    scale = float(noise_scale)
    sign = -1.0 if negate else 1.0
    for name, p in self.model_runner.model.named_parameters():
        gen = torch.Generator(device=p.device)
        gen.manual_seed(int(seed))
        noise = torch.randn(p.shape, dtype=p.dtype, device=p.device, generator=gen)
        if self._should_perturb(name):
            p.data.add_(sign * scale * noise)
        del noise

This means that:

  • Parameters with the same shape receive identical perturbations.
  • Parameters with different shapes also draw values from the same re-seeded random number generator, so their perturbations will not be independent.
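The first point is easy to demonstrate in isolation. A minimal sketch (standalone, not repo code) mimicking the loop above: re-seeding a fresh generator per parameter reproduces the exact same tensor for every same-shape parameter, whereas seeding once and drawing sequentially does not.

```python
import torch

SHAPE = (4, 4)  # stand-in for two parameters that share a shape

def noise_reseeded(seed):
    # Mimics the loop in perturb_self_weights: a new generator,
    # re-seeded with the same value, for every parameter.
    gen = torch.Generator()
    gen.manual_seed(seed)
    return torch.randn(SHAPE, generator=gen)

print(torch.equal(noise_reseeded(42), noise_reseeded(42)))  # True: identical noise

# Seeding once and drawing twice from the same stream gives distinct draws.
gen = torch.Generator()
gen.manual_seed(42)
print(torch.equal(torch.randn(SHAPE, generator=gen),
                  torch.randn(SHAPE, generator=gen)))  # False
```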

This behavior seems inconsistent with the paper, which states that the flattened weight vector $\theta \in \mathbb{R}^d$ is perturbed by a standard multivariate Gaussian noise vector $\epsilon \sim \mathcal{N}(0, I_d)$ multiplied by a scalar $\sigma$.
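One way to make the mismatch precise: for two parameters $\theta_1, \theta_2 \in \mathbb{R}^{d'}$ of the same shape, the per-parameter re-seeding yields $\epsilon_1 = \epsilon_2$, so the joint noise covariance is block-constant rather than the identity:

$$\mathrm{Cov}\begin{pmatrix}\epsilon_1 \\ \epsilon_2\end{pmatrix} = \begin{pmatrix} I_{d'} & I_{d'} \\ I_{d'} & I_{d'} \end{pmatrix} \neq I_{2d'}.$$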

This same pattern appears in several places:

  • perturb_self_weights:
    for name, p in self.model_runner.model.named_parameters():
        gen = torch.Generator(device=p.device)
        gen.manual_seed(int(seed))
        noise = torch.randn(p.shape, dtype=p.dtype, device=p.device, generator=gen)
  • restore_self_weights:
    for name, p in self.model_runner.model.named_parameters():
        gen = torch.Generator(device=p.device)
        gen.manual_seed(int(seed))
        noise = torch.randn(p.shape, dtype=p.dtype, device=p.device, generator=gen)
  • update_weights_from_seeds:
    for i, seed in enumerate(seeds):
        self._set_seed(seed)
        gen = torch.Generator(device=p.device)
        gen.manual_seed(int(seed))
        # Generate noise (in native precision, usually float16/bfloat16)
        noise = torch.randn(p.shape, dtype=p.dtype, device=p.device, generator=gen)
  • dump_noise_for_seed:
    for name, p in self.model_runner.model.named_parameters():
        gen = torch.Generator(device=p.device)
        gen.manual_seed(int(seed))
        noise = torch.randn(p.shape, dtype=p.dtype, device=p.device, generator=gen)
  • apply_perturbation:
    self._set_seed(seed)
    for name, p in self.model_runner.model.named_parameters():
        # Restore base weights first
        p.data.copy_(self._base_weights[name])
        # Then apply perturbation (skip visual encoder)
        gen = torch.Generator(device=p.device)
        gen.manual_seed(int(seed))
        noise = torch.randn(p.shape, dtype=p.dtype, device=p.device, generator=gen)
  • apply_averaged_perturbations:
    for (seed, sigma), weight in zip(seeds_sigmas, weights):
        gen = torch.Generator(device=p.device)
        gen.manual_seed(int(seed))
        noise = torch.randn(p.shape, dtype=p.dtype, device=p.device, generator=gen)
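For what it's worth, a minimal sketch of the fix that would apply at all of these sites, assuming the intent matches the paper: create and seed the generator once per seed, before the parameter loop, so each parameter draws a fresh, independent slice of the same stream. (perturb_weights_once_seeded is a hypothetical name, not from the repo.)

```python
import torch

def perturb_weights_once_seeded(model, seed, noise_scale, negate=False):
    """Hypothetical corrected variant: seed a single generator once,
    outside the parameter loop, so draws are sequential and independent
    across parameters."""
    sign = -1.0 if negate else 1.0
    scale = float(noise_scale)
    # Assumes all parameters live on one device; sharded models would
    # need a per-device generator cache.
    device = next(model.parameters()).device
    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed))
    with torch.no_grad():
        # named_parameters() iterates in a fixed order, so the same seed
        # always reproduces (and can undo) the same perturbation.
        for name, p in model.named_parameters():
            noise = torch.randn(p.shape, dtype=p.dtype, device=p.device,
                                generator=gen)
            p.add_(sign * scale * noise)
```

Because the draw order is deterministic, calling this again with the same seed and negate=True restores the original weights up to floating-point rounding, which is the property restore_self_weights depends on.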

There is a different pattern in utils/repro_seed.py, where the generator is seeded once before the loop over the parameters; this matches the behavior described in the paper. However, utils/repro_seed.py does not appear to be used or referenced anywhere else in the repo.

def perturb_model_weights(model, seed, sigma, is_negative=False):
    """
    Apply perturbation to model weights using a deterministic seed.

    Args:
        model: The model to perturb
        seed: Random seed for reproducibility
        sigma: Perturbation magnitude
        is_negative: If True, apply -noise instead of +noise (for antithetic sampling)
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    sign = -1.0 if is_negative else 1.0
    with torch.no_grad():
        for param in model.parameters():
            if param.requires_grad:
                noise = torch.randn(param.shape, generator=rng, dtype=param.dtype) * sigma * sign
                param.add_(noise.to(param.device))

Could the authors clarify whether this behavior is intended?
