Stochastic Random Graph Generation via Distance-Weighted Gumbel Top-k Sampling #10585

@LongSNU

🚀 The feature, motivation and pitch

🚀 Motivation

Standard graph construction in PyG often relies on deterministic spatial methods like knn_graph or radius-based queries. However, recent work in generative protein design (notably the Chroma model) demonstrates the power of Random Graph Neural Networks that scale sub-quadratically ($O(N \log N)$ or $O(N)$) while capturing global context.

DOI of the paper: https://doi.org/10.1038/s41586-023-06728-8
Ingraham, J. B., et al. "Illuminating protein space with a programmable generative model." Nature (2023). Supplementary Information, Section E: Random Graph Neural Networks.
Algorithm 1: Random graph generation.

The core idea is inspired by fast N-body methods (like the Barnes-Hut algorithm) where distant interactions are modeled more coarsely. By combining deterministic local edges with stochastic long-range edges sampled via a distance-based propensity, models can process massive systems (e.g., 60,000 residues) on standard hardware.

💡 Proposed Feature

I propose adding a utility function (e.g., torch_geometric.nn.random_graph) that implements Algorithm 1: Random graph generation from the Chroma model.

Mathematical Logic:
For each node $i$, the algorithm selects $k$ stochastic neighbors by:

  1. Calculating inter-node distances $D_{ij}$.
  2. Defining an attachment propensity score $c(D_{ij})$. For protein structures, the authors recommend inverse cubic propensity ($D_{ij}^{-3}$), which translates to $c(D_{ij}) = -3 \log(D_{ij})$ in log-space.
  3. Sampling uniform noise $U_{ij} \sim \text{Uniform}(0, 1)$ and calculating perturbed log-probabilities $Z_{ij}$ using the Gumbel Top-k trick, where $\lambda_{\mathcal{G}}$ is an inverse-temperature hyperparameter:
    $$Z_{ij} = \lambda_{\mathcal{G}} c(D_{ij}) - \log(-\log(U_{ij}))$$
  4. Selecting the top $k$ edges for each node based on these scores.
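
Why steps 3 and 4 amount to distance-weighted sampling: $-\log(-\log(U))$ is standard Gumbel noise, and taking the argmax of Gumbel-perturbed log-weights draws from the categorical distribution proportional to those weights, here $D_{ij}^{-3\lambda_{\mathcal{G}}}$. Taking the top $k$ instead of the top 1 corresponds to sampling $k$ neighbors without replacement. The snippet below is a small dependency-free check of the top-1 case on made-up distances. The helper name `gumbel_top_k` and the toy values are assumptions for illustration, not part of the paper or of PyG.

```python
import math
import random


def gumbel_top_k(log_weights, k, rng):
    """Indices of the k largest Gumbel-perturbed log-weights (hypothetical helper)."""
    scored = []
    for j, log_w in enumerate(log_weights):
        # random() can return exactly 0.0, so clamp before taking logs.
        u = max(rng.random(), 1e-300)
        scored.append((log_w - math.log(-math.log(u)), j))
    scored.sort(reverse=True)
    return [j for _, j in scored[:k]]


# Toy distances from one node to three candidates, with lambda_g = 1.
distances = [1.0, 2.0, 4.0]
lambda_g = 1.0
log_weights = [lambda_g * -3.0 * math.log(d) for d in distances]

# Target: probabilities proportional to D^(-3 * lambda_g).
weights = [math.exp(lw) for lw in log_weights]
target = [w / sum(weights) for w in weights]

# Empirical frequency of each candidate being the top-1 pick.
rng = random.Random(0)
n_draws = 20000
counts = [0] * len(distances)
for _ in range(n_draws):
    counts[gumbel_top_k(log_weights, 1, rng)[0]] += 1
empirical = [c / n_draws for c in counts]

print("target   :", [round(p, 3) for p in target])
print("empirical:", [round(p, 3) for p in empirical])
```

With these distances the target probabilities are 64/73, 8/73 and 1/73, so the nearest candidate dominates but the farthest is still picked occasionally. Raising `lambda_g` sharpens the distribution toward the nearest candidate.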

🛠 Implementation Detail

Example Interface:

import torch
from torch_geometric.utils import remove_self_loops


def random_graph(x, k_random=40, lambda_g=1.0, seed=None, device="cpu"):
    """
    Random Graph Generation

    Args:
        x: Node coordinates, array-like or float Tensor of shape (N, 3)
        k_random: Number of stochastic long-range edges per node (capped at N - 1)
        lambda_g: Inverse temperature hyperparameter (controls how strongly distance affects the outcome; higher lambda_g prioritizes closer nodes)
        seed: Optional integer seed for reproducible sampling
        device: Device on which the graph is built

    Returns:
        edge_index: LongTensor of shape (2, N * k) with source nodes in row 0
    """
    # as_tensor avoids the copy (and warning) torch.tensor gives for Tensor inputs
    x = torch.as_tensor(x, device=device)
    N = x.size(0)
    k = min(k_random, N - 1)  # a node has at most N - 1 possible neighbors

    # Stochastic Long-Range Edges
    # Calculate all-to-all Euclidean distances D_ij (dense: O(N^2) time and memory)
    dist_matrix = torch.cdist(x, x)

    # Define inverse cubic propensity: log p ∝ -3 * log(D)
    # Clamp to avoid log(0) on the diagonal and for coincident points
    eps = 1e-6
    c_dist = -3.0 * torch.log(dist_matrix.clamp_min(eps))

    # Sample uniform noise U_ij ~ Uniform(0,1) from a local generator,
    # so that seeding does not reset the global RNG state
    generator = None
    if seed is not None:
        generator = torch.Generator(device=x.device).manual_seed(seed)
    U = torch.rand(dist_matrix.shape, generator=generator, dtype=dist_matrix.dtype, device=x.device)
    U = U.clamp_min(torch.finfo(U.dtype).tiny)  # torch.rand can return exactly 0

    # Calculate perturbed log probabilities Z_ij with Gumbel noise
    # Z_ij = lambda_g * c(D_ij) - log(-log(U_ij))
    gumbel_noise = -torch.log(-torch.log(U))
    Z = (lambda_g * c_dist) + gumbel_noise

    # Mask self-loops so a node doesn't connect to itself randomly
    Z.fill_diagonal_(float('-inf'))

    # Select top k neighbors per node (indices are distinct within each row)
    _, top_indices = torch.topk(Z, k=k, dim=1)

    # Convert to COO format (edge_index)
    row = torch.arange(N, device=x.device).repeat_interleave(k)  # Something like [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, ... N-1, N-1, N-1, N-1]
    col = top_indices.reshape(-1)                                # Something like [1, 5, 4, 6, 2, 3, 4, 2, 3, 5, 8, 15, ...]
    edge_index = torch.stack([row, col], dim=0)

    # Safety net: self-loops are already masked above, so this is normally a no-op
    edge_index, _ = remove_self_loops(edge_index)

    return edge_index
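
The motivation mentions combining deterministic local edges with stochastic long-range edges, while the example interface samples only the stochastic part. Below is a minimal sketch of one way the two could be merged, using only `torch`. The helper name `hybrid_edge_index`, its parameters, and the choice to exclude already-selected nearest neighbors from the random draw are assumptions for illustration, not taken from the Chroma paper or from PyG.

```python
import torch


def hybrid_edge_index(pos, k_local=2, k_random=2, lambda_g=1.0, generator=None):
    """Hypothetical helper: k_local nearest neighbors plus k_random sampled ones."""
    n = pos.size(0)
    if k_local + k_random > n - 1:
        raise ValueError("k_local + k_random must not exceed N - 1")

    dist = torch.cdist(pos, pos)
    dist.fill_diagonal_(float("inf"))  # a node is never its own neighbor

    # Deterministic local edges: the k_local smallest distances per row.
    local_idx = torch.topk(dist, k=k_local, dim=1, largest=False).indices

    # Stochastic edges: Gumbel-perturbed log-propensity lambda_g * (-3 log D).
    u = torch.rand(dist.shape, generator=generator, dtype=dist.dtype, device=pos.device)
    u = u.clamp_min(torch.finfo(u.dtype).tiny)  # avoid log(0)
    gumbel = -torch.log(-torch.log(u))
    z = lambda_g * (-3.0 * torch.log(dist.clamp_min(1e-6))) + gumbel

    # Exclude self-loops and the local neighbors already selected.
    z.fill_diagonal_(float("-inf"))
    z.scatter_(1, local_idx, float("-inf"))
    random_idx = torch.topk(z, k=k_random, dim=1).indices

    # Assemble COO edge_index with source nodes in row 0.
    col = torch.cat([local_idx, random_idx], dim=1)
    row = torch.arange(n, device=pos.device).repeat_interleave(col.size(1))
    return torch.stack([row, col.reshape(-1)], dim=0)


gen = torch.Generator().manual_seed(0)
pos = torch.randn(10, 3, generator=gen)
edge_index = hybrid_edge_index(pos, k_local=2, k_random=2, generator=gen)
print(edge_index.shape)  # torch.Size([2, 40])
```

Masking the local neighbors before the random draw keeps every node at exactly `k_local + k_random` distinct outgoing edges. Without the mask, the inverse cubic propensity would often re-pick the nearest neighbors and the two edge sets would overlap.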
