Required prerequisites
What version of TorchOpt are you using?
0.7.3
System information
sys.version: 3.10.13 (main, Sep 18 2023, 17:18:13) [GCC 12.3.1 20230526]
sys.platform: linux
torchopt==0.7.3
torch==2.6.0
functorch==2.6.0
Problem description
Running the provided code shows a slight difference between the TorchOpt updates and the torch.optim updates, where none was expected. The difference grows with higher learning rates.
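A possible explanation (my assumption, not confirmed anywhere in this report): both libraries implement the same Adam math, but if they order the elementwise operations differently, float arithmetic's non-associativity lets the results drift apart in the last bits, with the gap scaled by the learning rate. A minimal, library-free illustration of that non-associativity:

```python
# Floating-point addition is not associative: two mathematically
# equivalent expressions can round to different values.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c   # 0.6000000000000001
right = a + (b + c)  # 0.6
print(left == right)  # False
```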
Reproducible example code
The Python snippets:
import torch
import torchopt
lr = 0.01
torch.manual_seed(1)
model = torch.nn.Linear(3, 1)
optim = torchopt.Adam(model.parameters(), lr=lr)
print(next(model.parameters()))
n_updates = 30
x = torch.rand((n_updates, 3), requires_grad=True)
for i in range(n_updates):
    b_x = x[i]
    y = torch.rand((1,), requires_grad=True)
    out = model(b_x)
    loss = ((out - y) ** 2).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
params_1 = next(model.parameters()).detach()
print(next(model.parameters()))
torch.manual_seed(1)
model = torch.nn.Linear(3, 1)
optim = torch.optim.Adam(model.parameters(), lr=lr)
print(next(model.parameters()))
x = torch.rand((n_updates, 3), requires_grad=True)
for i in range(n_updates):
    b_x = x[i]
    y = torch.rand((1,), requires_grad=True)
    out = model(b_x)
    loss = ((out - y) ** 2).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
params_2 = next(model.parameters()).detach()
print(next(model.parameters()))
print("All close:", torch.allclose(params_1, params_2))
print("Difference:", params_1 - params_2)
# Output:
# Parameter containing:
# tensor([[ 0.2975, -0.2548, -0.1119]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.4131, -0.1184, 0.0238]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.2975, -0.2548, -0.1119]], requires_grad=True)
# Parameter containing:
# tensor([[ 0.4131, -0.1184, 0.0238]], requires_grad=True)
# All close: False
# Difference: tensor([[-1.4901e-07, -4.3213e-07, -2.9616e-07]])
Steps to reproduce:
- Run the provided code.
- Observe that larger learning rates cause larger divergences.
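One observation on the comparison itself: the reported gap (up to ~4.3e-07 on weights of magnitude ~0.02 to 0.4) is at the scale of float32 machine epsilon (~1.19e-07), and `torch.allclose` with its default tolerances (rtol=1e-05, atol=1e-08) rejects last-bit noise on the small-magnitude entry. A sketch of a tolerance-aware comparison, using hypothetical stand-in tensors with the difference values from the output above:

```python
import torch

# Stand-ins for the final weights from the two snippets, differing
# only by the last-bit-scale deltas reported in the output.
params_1 = torch.tensor([0.4131, -0.1184, 0.0238])
params_2 = params_1 + torch.tensor([-1.49e-07, -4.32e-07, -2.96e-07])

# Default tolerances flag the small-magnitude entry as unequal ...
print(torch.allclose(params_1, params_2))  # False for this data

# ... while an absolute tolerance a few float32 ulps wide accepts it.
eps32 = torch.finfo(torch.float32).eps  # ~1.19e-07
print(torch.allclose(params_1, params_2, atol=10 * eps32))  # True
```

This does not make the divergence disappear; it only distinguishes rounding-level disagreement from a genuine algorithmic mismatch, which larger learning rates could still reveal.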
Traceback
Expected behavior
Expected (almost) zero difference between the results of TorchOpt and torch.optim.
Additional context
No response