
Aggregate selects same element multiple times #9

@LemonPi


I'm trying to evaluate NNN against conventional KNN on a simple test case.
The test case is to find the 5 nearest neighbours of each element of a random permutation of indices (so the result is easy to verify by eye).
The problem is that the aggregated output contains the same value for all 5 neighbours.
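For the by-eye check described above, a conventional (hard) KNN baseline takes only a few lines. The sketch below uses plain NumPy rather than `non_local`. The helper name `knn_values` is hypothetical, and the sketch assumes squared Euclidean distance with each point included as a candidate for itself, mirroring `ye = xe` and the full index matrix `I` in the setup below.

```python
import numpy as np

def knn_values(x: np.ndarray, k: int) -> np.ndarray:
    """Hard k-nearest-neighbour lookup (hypothetical helper).

    x has shape (N, D). Returns an (N, k) array whose row i holds the first
    feature of the k rows of x nearest to row i, nearest first. Every point
    is its own candidate, so column 0 is the query itself (distance 0)."""
    # Pairwise squared Euclidean distances, shape (N, N).
    d = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)
    # Indices of the k smallest distances per row, in ascending order.
    idx = np.argsort(d, axis=1, kind="stable")[:, :k]
    return x[idx, 0]

N, K = 50, 5
x = np.random.permutation(N).astype(np.float64).reshape(N, 1)
neighbours = knn_values(x, K)
print(x[:3, 0])
print(neighbours[:3])
```

For a query value v away from the ends of the range, each row contains v followed by v ± 1 and v ± 2 (the order within each tied pair depends on the permutation). That gives five distinct values, which is the kind of result this test expects.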

Problem setup:

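```python
import numpy as np
import torch
import non_local

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

N = 50
# Aggregate the 5 nearest neighbours.
nn = non_local.N3AggregationBase(5, temp_opt={"external_temp": False})

# A random permutation of 0..N-1, shaped (batch=1, N items, 1 feature).
x = torch.tensor(np.random.permutation(N), dtype=torch.float, requires_grad=True)
x = x.reshape(1, N, 1).to(device)
# Use the raw values as embeddings for both the queries and the candidates.
xe = x
ye = xe
# Candidate indices: every query may select from all N items, shape (1, N, N).
I = torch.arange(N, dtype=torch.long).repeat(N, 1).reshape(1, N, N).to(device)

z = nn(x, xe, ye, I)
```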

The aggregated output z is:

```
tensor([[[[10.0001, 10.0001, 10.0001, 10.0001, 10.0001]],
         [[42.0001, 42.0001, 42.0001, 42.0001, 42.0001]],
         [[22.0001, 22.0001, 22.0001, 22.0001, 22.0001]],
...
```

Is this the expected behaviour, and am I interpreting the result wrong? If so, what is the aggregated output z supposed to represent?
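One way to explore the question is to compare against a continuous (soft) relaxation of KNN, in the spirit of the one described in the Neural Nearest Neighbors Networks paper. The sketch below is a from-scratch NumPy re-implementation, not the `non_local` code. The name `soft_knn`, the squared Euclidean distance, and where the temperature enters are all assumptions, and `N3AggregationBase` may differ in these details. Each of the k outputs is a softmax-weighted average over all candidates. Before the next step, the logits of candidates that already received weight are lowered by log(1 − w).

```python
import numpy as np

def soft_knn(query: np.ndarray, candidates: np.ndarray, k: int,
             temperature: float = 1.0) -> np.ndarray:
    """Continuous k-nearest-neighbour relaxation (hypothetical sketch).

    query has shape (D,), candidates has shape (M, D). Returns (k, D):
    row j is a weighted average of all candidates, not a single candidate."""
    # Initial logits: negative squared distances divided by the temperature.
    logits = -((candidates - query) ** 2).sum(axis=-1) / temperature
    outputs = []
    for _ in range(k):
        # Numerically stable softmax over the candidates.
        e = np.exp(logits - logits.max())
        w = e / e.sum()
        outputs.append(w @ candidates)
        # Penalise candidates that already received weight; the clip avoids log(0).
        logits = logits + np.log(np.clip(1.0 - w, 1e-12, None))
    return np.stack(outputs)

N, K = 50, 5
x = np.random.permutation(N).astype(np.float64).reshape(N, 1)
z_soft = soft_knn(np.array([25.0]), x, K)
print(z_soft[:, 0])
```

Under these assumptions, take a query such as 25. The candidates at equal distance lie symmetrically on either side (24 and 26, 23 and 27, ...) and receive equal weights, so every weighted average comes out at about 25. Changing the temperature sharpens or flattens the weights but keeps them symmetric. With this 1-D test case, the averages may therefore repeat the query value even when the weighting works as intended. This is a possible reading of the output above, not a confirmed explanation of what `non_local` computes.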
