I'm trying to evaluate NNN against conventional KNN on a simple test case.
The test case is to find the 5 nearest neighbours of each item in a permutation of indices, so the result is easy to verify intuitively.
The problem is that the aggregate output contains the same value for all 5 neighbours.
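For comparison, the conventional KNN baseline on this test case can be computed directly. The sketch below is plain PyTorch; the helper name `knn_values` is hypothetical and not part of `non_local`. Because the items are the distinct integers 0..N-1, the 5 nearest neighbours of an interior value v are v itself, v±1 and v±2, so the baseline is easy to check by eye.

```python
import torch

def knn_values(x: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: values of the k nearest neighbours of each item.

    x has shape (N,) holding scalar items; the result has shape (N, k),
    ordered by increasing distance. Each item is its own nearest neighbour
    (distance 0), mirroring the xe = ye setup in this issue.
    """
    # Pairwise absolute distances between all scalar items: shape (N, N).
    d = (x[:, None] - x[None, :]).abs()
    # Indices of the k smallest distances per row, sorted ascending.
    idx = torch.topk(d, k, dim=1, largest=False).indices
    # Gather the neighbour values: shape (N, k).
    return x[idx]

N = 50
x = torch.randperm(N).float()  # shuffled 0..N-1, like np.random.permutation
print(knn_values(x, 5)[:3])
```

Ties (v-1 vs v+1) may come back in either order, but the set of 5 values per row is fixed.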
Problem setup:
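import torch
import numpy as np
import non_local

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
N = 50
# k = 5 neighbours; move the module to the same device as the inputs
agg = non_local.N3AggregationBase(5, temp_opt={"external_temp": False}).to(device)
# Database items: a random permutation of 0..N-1, reshaped to (1, N, 1)
x = torch.tensor(np.random.permutation(N), dtype=torch.float, requires_grad=True)
x = x.reshape(1, N, 1).to(device)
# Use the raw values as both the database (xe) and query (ye) embeddings
xe = x
ye = xe
# Candidate indices: every query may choose from all N items, shape (1, N, N)
I = torch.arange(N, dtype=torch.long).repeat(N, 1).reshape(1, N, N).to(device)
z = agg(x, xe, ye, I)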
The resulting aggregate output z is:
tensor([[[[10.0001, 10.0001, 10.0001, 10.0001, 10.0001]],
[[42.0001, 42.0001, 42.0001, 42.0001, 42.0001]],
[[22.0001, 22.0001, 22.0001, 22.0001, 22.0001]],
...
Is this expected, and am I interpreting the result incorrectly? If so, what is the aggregate output z supposed to represent?
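A toy model can help check whether identical columns are plausible. The sketch below is not the `non_local` implementation: it assumes, as a hypothesis, that each of the k output slots is a softmax-weighted average of the database items, with later slots down-weighting items already selected via log(1 - alpha). The function `soft_knn` and the `temp` value are illustrative assumptions. Under that assumption, for an interior value v in this test case, v-1 and v+1 (and v-2 and v+2, and so on) are equidistant and receive equal weight, so every slot would average back to roughly v.

```python
import torch

def soft_knn(x: torch.Tensor, k: int, temp: float) -> torch.Tensor:
    """Hypothetical toy continuous k-nearest-neighbour selection.

    ASSUMPTION: modelled on a softmax relaxation of KNN in which each later
    slot down-weights items already selected via log(1 - alpha). This is an
    illustration only, not the non_local implementation.
    x: (N,) scalar items. Returns (N, k); slot j is a weighted average of x.
    """
    d = (x[:, None] - x[None, :]) ** 2  # squared distances, shape (N, N)
    logits = -d / temp
    slots = []
    for _ in range(k):
        alpha = torch.softmax(logits, dim=1)  # weights over candidates per query
        slots.append(alpha @ x)               # soft neighbour value per query
        # Penalise candidates in proportion to the weight they just received;
        # the clamp avoids -inf when alpha rounds to exactly 1.
        logits = logits + torch.log1p(-alpha).clamp(min=-1e4)
    return torch.stack(slots, dim=1)

N = 50
x = torch.randperm(N).float()
z = soft_knn(x, k=5, temp=1.0)
print(z[x == 10])  # interior value: every slot is approximately 10
```

Near the ends of the range (e.g. value 0) the neighbours lie on one side only, so in this toy model the later slots drift away from the query value instead of repeating it.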