🚀 The feature, motivation and pitch
Motivation
Standard graph construction in PyG often relies on deterministic spatial methods such as knn_graph or radius-based queries. However, recent work in generative protein design (notably the Chroma model) demonstrates the power of Random Graph Neural Networks, which scale sub-quadratically with system size.
Reference: Ingraham, J. B., et al. "Illuminating protein space with a programmable generative model." Nature (2023). https://doi.org/10.1038/s41586-023-06728-8. See Supplementary Information, Section E (Random Graph Neural Networks), Algorithm 1 (Random graph generation).
The core idea is inspired by fast N-body methods (like the Barnes-Hut algorithm) where distant interactions are modeled more coarsely. By combining deterministic local edges with stochastic long-range edges sampled via a distance-based propensity, models can process massive systems (e.g., 60,000 residues) on standard hardware.
💡 Proposed Feature
I propose adding a utility function (e.g., torch_geometric.nn.random_graph) that implements Algorithm 1 (Random graph generation) from the Chroma model.
Mathematical Logic:

For each node $i$:

- Calculate inter-node distances $D_{ij}$.
- Define an attachment propensity score $c(D_{ij})$. For protein structures, the authors recommend an inverse cubic propensity ($D_{ij}^{-3}$), which translates to $c(D_{ij}) = -3 \log(D_{ij})$ in log-space.
- Sample uniform noise $U_{ij} \sim \text{Uniform}(0, 1)$ and compute perturbed log-probabilities $Z_{ij}$ via the Gumbel top-k trick:

$$Z_{ij} = \lambda_{\mathcal{G}} \, c(D_{ij}) - \log(-\log(U_{ij}))$$

- Select the top $k$ edges for each node based on these scores.
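To make the Gumbel top-k step concrete, here is a minimal single-node sketch in plain Python (the helper name gumbel_topk_neighbors and the toy distances are illustrative only, not part of the proposed API):

```python
import math
import random

def gumbel_topk_neighbors(dists, k, lambda_g=1.0, rng=None):
    """Pick k neighbor indices for one node via the Gumbel top-k trick.

    Scores each candidate j as Z_j = lambda_g * (-3 * log(D_j)) - log(-log(U_j)),
    which is equivalent to sampling k indices without replacement with
    probability proportional to D_j ** (-3 * lambda_g).
    Illustrative sketch only, not the proposed PyG implementation.
    """
    rng = rng or random.Random()
    eps = 1e-6
    scores = []
    for j, d in enumerate(dists):
        c = -3.0 * math.log(d + eps)                   # inverse cubic propensity (log-space)
        u = rng.random()                               # U_j ~ Uniform(0, 1)
        gumbel = -math.log(-math.log(u + eps) + eps)   # Gumbel(0, 1) noise
        scores.append((lambda_g * c + gumbel, j))
    scores.sort(reverse=True)                          # highest perturbed score first
    return [j for _, j in scores[:k]]

# One node with 8 candidate neighbors at increasing distances:
dists = [1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0]
picked = gumbel_topk_neighbors(dists, k=3, lambda_g=1.0, rng=random.Random(0))
```

Because the noise is resampled each call, the selected neighbor set varies between calls, but nearby candidates are picked more often; raising lambda_g sharpens the preference for close neighbors.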
Implementation Detail
Example Interface:
```python
import torch

from torch_geometric.utils import remove_self_loops


def random_graph(x, k_random=40, lambda_g=1.0, seed=None, device="cpu"):
    """Random graph generation (Chroma, Algorithm 1).

    Args:
        x: Tensor of node coordinates with shape (N, 3).
        k_random: Number of stochastic long-range edges per node.
        lambda_g: Inverse temperature hyperparameter; higher values
            prioritize closer nodes when sampling edges.
        seed: Optional seed for reproducible sampling.
        device: Device on which to build the graph.
    """
    x = torch.as_tensor(x, device=device)
    N = x.size(0)

    # Stochastic long-range edges:
    # compute all-to-all Euclidean distances D_ij (O(N^2) memory).
    dist_matrix = torch.cdist(x, x)

    # Inverse cubic propensity in log-space: c(D) = -3 * log(D).
    # Add epsilon to avoid log(0) on the diagonal.
    eps = 1e-6
    c_dist = -3.0 * torch.log(dist_matrix + eps)

    # Sample uniform noise U_ij ~ Uniform(0, 1).
    if seed is not None:
        torch.manual_seed(seed)
    U = torch.rand_like(dist_matrix)

    # Perturbed log-probabilities with Gumbel noise:
    # Z_ij = lambda_g * c(D_ij) - log(-log(U_ij)).
    gumbel_noise = -torch.log(-torch.log(U + eps) + eps)
    Z = lambda_g * c_dist + gumbel_noise

    # Mask the diagonal so a node cannot select itself.
    Z.fill_diagonal_(float('-inf'))

    # Select the top k_random neighbors per node.
    _, top_indices = torch.topk(Z, k=k_random, dim=1)

    # Convert to COO format (edge_index):
    # row = [0, ..., 0, 1, ..., 1, ..., N-1], each index repeated k_random times.
    row = torch.arange(N, device=x.device).repeat_interleave(k_random)
    col = top_indices.reshape(-1)
    edge_index = torch.stack([row, col], dim=0)

    # Safety net: drop any residual self-loops.
    edge_index, _ = remove_self_loops(edge_index)
    return edge_index
```

Alternatives
No response
Additional context
No response