README.md
```bash
conda create -n polygraph-benchmark python=3.10
conda activate polygraph-benchmark
```

Then install
```bash
pip install -e .
```

If you'd like to use SBM graph dataset validation with `graph_tool`, use a mamba or pixi environment. More information is available in the documentation.


PolyGraph is a Python library for evaluating graph generative models by providing standardized datasets and metrics (including PolyGraphScore).

## At a glance

Here are the datasets and metrics this library provides:
- PolyGraphScore: `StandardPGS`, `MolecularPGS` (for molecule descriptors).
- Validation/Uniqueness/Novelty: `VUN`.
- Uncertainty quantification for benchmarking (`GaussianTVMMD2BenchmarkInterval`, `RBFMMD2Benchmark`, `PGS5Interval`).
- 🧩 **Extendable**: users can instantiate custom metrics by specifying descriptors, kernels, or classifiers (`PolyGraphScore`, `DescriptorMMD2`). PolyGraph defines all necessary interfaces but imposes no requirements on the data type of graph objects.
- ⚙️ **Interoperability**: works with PyTorch Geometric and NetworkX; runs on Apple Silicon Macs and Linux; caching via `POLYGRAPH_CACHE_DIR`.
- ✅ **Tested, type-checked, and documented**


<details>
<summary><strong>⚠️ Important - Dataset Usage Warning</strong></summary>

**To help reproduce previous results, we provide the following datasets:**
- `PlanarGraphDataset`
- `SBMGraphDataset`
- `LobsterGraphDataset`

However, these should not be used for benchmarking, as they yield unreliable metric estimates (see our paper for more details).

We provide larger datasets that should be used instead:
- `PlanarLGraphDataset`
- `SBMLGraphDataset`
- `LobsterLGraphDataset`

</details>

## Tutorial

Our [demo script](demo_polygraph.py) showcases some basic features of our library in action.
For more advanced usage (namely, defining custom metrics), we refer to our [second demo script](demo_custom_metrics.py).


### Datasets
Instantiate a benchmark dataset as follows:

```python
import networkx as nx
from polygraph.datasets import PlanarGraphDataset
reference = PlanarGraphDataset("test").to_nx()
generated = [nx.erdos_renyi_graph(64, 0.1) for _ in range(40)]
```
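As a quick sanity check (plain NetworkX, independent of PolyGraph), one can confirm that such dense Erdős–Rényi graphs are, unlike the reference set, essentially never planar; this is an illustrative sketch, not part of the library:

```python
import networkx as nx

# Dense random graphs: ~200 expected edges on 64 nodes, far above the
# planarity bound of 3n - 6 = 186 edges in most draws.
generated = [nx.erdos_renyi_graph(64, 0.1, seed=i) for i in range(5)]
num_planar = sum(nx.check_planarity(g)[0] for g in generated)
print(f"{num_planar}/{len(generated)} generated graphs are planar")

# A small cycle, by contrast, is planar.
assert nx.check_planarity(nx.cycle_graph(10))[0]
```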


### Metrics

#### Maximum Mean Discrepancy
To compute existing MMD2 formulations (e.g. based on the TV pseudokernel), one can use the following:

```python
from polygraph.metrics import GaussianTVMMD2Benchmark # Can also be RBFMMD2Benchmark

gtv_benchmark = GaussianTVMMD2Benchmark(reference)
print(gtv_benchmark.compute(generated))  # {'orbit': ..., 'clustering': ..., ...}
```
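For intuition, the unbiased ("umve") MMD² estimator behind these classes can be sketched in plain NumPy. This is an illustrative sketch of the standard estimator on made-up descriptor vectors, not PolyGraph's internal implementation:

```python
import numpy as np

def mmd2_unbiased(K_xx, K_yy, K_xy):
    """Unbiased MMD^2 from kernel matrices; within-sample sums exclude the diagonal."""
    m, n = K_xx.shape[0], K_yy.shape[0]
    xx = (K_xx.sum() - np.trace(K_xx)) / (m * (m - 1))
    yy = (K_yy.sum() - np.trace(K_yy)) / (n * (n - 1))
    xy = 2.0 * K_xy.mean()
    return xx + yy - xy

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8))          # stand-in descriptors of reference graphs
Y_near = rng.normal(size=(60, 8))     # same distribution as X
Y_far = rng.normal(size=(60, 8)) + 1  # shifted: a different distribution

# With a linear kernel k(a, b) = a . b, MMD^2 estimates the squared
# distance between the descriptor means of the two samples.
near = mmd2_unbiased(X @ X.T, Y_near @ Y_near.T, X @ Y_near.T)
far = mmd2_unbiased(X @ X.T, Y_far @ Y_far.T, X @ Y_far.T)
```

Samples from the same distribution yield an estimate near zero, while the shifted sample yields a clearly larger value.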
Similarly, you can compute our proposed PolyGraphScore, like so:

```python
from polygraph.metrics import StandardPGS

pgs = StandardPGS(reference)
print(pgs.compute(generated)) # {'polygraphscore': ..., 'polygraphscore_descriptor': ..., 'subscores': {'orbit': ..., }}
```

#### Validity, uniqueness and novelty
VUN values follow a similar interface:

```python
from polygraph.metrics import VUN
reference_ds = PlanarGraphDataset("test")
vun = VUN(reference_ds.to_nx(), validity_fn=reference_ds.is_valid)
print(vun.compute(generated))  # {'valid': ..., 'valid_unique_novel': ..., ...}
```

For MMD and PGS, uncertainty quantification is obtained through subsampling. For VUN, a confidence interval is obtained with a binomial test.
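To illustrate the binomial approach, here is a hand-rolled Wilson score interval for a validity proportion (a standard textbook formula, not necessarily the library's exact procedure; the counts 30-of-40 are made up):

```python
import math

def wilson_interval(successes, n, z=1.96):
    """Wilson score interval for a binomial proportion (95% CI by default)."""
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = (z / denom) * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2))
    return center - half, center + half

# e.g. 30 valid graphs out of 40 generated
lo, hi = wilson_interval(30, 40)
print(f"validity: 0.75, 95% CI ~ [{lo:.3f}, {hi:.3f}]")
```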

For `VUN`, these results can be obtained by specifying a confidence level when instantiating the metric.

For the other metrics, the `Interval` suffix indicates the class variant that implements subsampling.
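The idea behind the subsampling classes can be sketched generically. This is a hypothetical helper, not PolyGraph's API; `metric_fn` stands in for any set-to-set metric, and the toy data below are plain numbers rather than graphs:

```python
import numpy as np

def subsample_interval(reference, generated, metric_fn, num_subsamples=20,
                       subsample_size=None, alpha=0.05, seed=0):
    """Estimate an uncertainty interval for a metric by repeated subsampling."""
    rng = np.random.default_rng(seed)
    if subsample_size is None:
        subsample_size = min(len(reference), len(generated)) // 2
    values = []
    for _ in range(num_subsamples):
        ref_idx = rng.choice(len(reference), size=subsample_size, replace=False)
        gen_idx = rng.choice(len(generated), size=subsample_size, replace=False)
        values.append(metric_fn([reference[i] for i in ref_idx],
                                [generated[i] for i in gen_idx]))
    # Empirical (1 - alpha) interval over the subsampled metric values.
    lo, hi = np.quantile(values, [alpha / 2, 1 - alpha / 2])
    return lo, hi

ref = list(np.random.default_rng(0).normal(size=200))
gen = list(np.random.default_rng(1).normal(loc=0.5, size=200))
lo, hi = subsample_interval(ref, gen,
                            lambda a, b: abs(np.mean(a) - np.mean(b)))
print(f"metric interval ~ [{lo:.3f}, {hi:.3f}]")
```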

```python
for metric in tqdm(metrics):
    metric.compute(generated)
```

demo_custom_metrics.py
```python
from typing import Iterable, Collection

import numpy as np
import networkx as nx
from loguru import logger
from sklearn.linear_model import LogisticRegression

from polygraph.datasets import PlanarGraphDataset, SBMGraphDataset
from polygraph.metrics.base import PolyGraphScore, DescriptorMMD2
from polygraph.utils.kernels import LinearKernel
from polygraph.utils.descriptors import ClusteringHistogram

logger.disable("polygraph")


def betweenness_descriptor(graphs: Iterable[nx.Graph]) -> np.ndarray:
    """A custom graph descriptor that computes betweenness centrality.

    This implements the polygraph.utils.descriptors.GraphDescriptor interface.
    """
    histograms = []
    for graph in graphs:
        btw_values = list(nx.betweenness_centrality(graph).values())
        histograms.append(
            np.histogram(btw_values, bins=100, range=(0.0, 1.0), density=True)[0]
        )
    return np.stack(histograms, axis=0)


def calculate_custom_mmd(
    reference: Collection[nx.Graph], generated: Collection[nx.Graph]
):
    """Calculate a customized MMD between a reference dataset and a generated dataset.

    This MMD uses a linear kernel based on betweenness centrality histograms.
    It is estimated using the unbiased minimum-variance estimator.
    """
    print("Calculating custom MMD...")
    mmd = DescriptorMMD2(
        reference, kernel=LinearKernel(betweenness_descriptor), variant="umve"
    )
    print(f"Custom MMD: {mmd.compute(generated)}\n")


def calculate_custom_pgs(
    reference: Collection[nx.Graph], generated: Collection[nx.Graph]
):
    """Calculate a customized PGS between a reference dataset and a generated dataset.

    This PGS uses betweenness centrality and clustering coefficients as graph
    descriptors. Instead of TabPFN, it uses logistic regression.

    PolyGraphScore may be instantiated with any descriptors implementing the
    `polygraph.utils.descriptors.GraphDescriptor` interface and any classifier
    implementing the `polygraph.metrics.base.polygraphscore.ClassifierProtocol`
    interface (i.e., the sklearn interface).
    """
    print("Calculating custom PGS...")
    pgs = PolyGraphScore(
        reference,
        descriptors={
            "betweenness": betweenness_descriptor,
            "clustering": ClusteringHistogram(bins=100),
        },
        classifier=LogisticRegression(),
    )
    result = pgs.compute(generated)

    print(f"Overall PGS: {result['polygraphscore']:.6f}")
    print(f"Score Descriptor: {result['polygraphscore_descriptor']}")
    print("\nSubscores:")
    for metric, score in result["subscores"].items():
        print(f"  {metric.capitalize()}: {score:.6f}")
    print()


if __name__ == "__main__":
    reference = list(PlanarGraphDataset("val").to_nx())
    generated = list(SBMGraphDataset("val").to_nx())

    calculate_custom_mmd(reference, generated)
    print("=" * 50, "\n")
    calculate_custom_pgs(reference, generated)
```
demo_polygraph.py
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""demo_polygraph.py

In this file, we aim to demonstrate some of the features of the polygraph library.
"""

import os
import warnings
from typing import List

import networkx as nx
from appdirs import user_cache_dir
from loguru import logger

import polygraph
from polygraph.datasets import ProceduralPlanarGraphDataset
from polygraph.metrics import (
    VUN,
    GaussianTVMMD2Benchmark,
    RBFMMD2Benchmark,
    StandardPGS,
)


REF_SMILES = [
    "Nc1ncnc2c1ncn2C1OC(CO)CC1F",
    "C=CCc1c(OC(C)=O)c2cccnc2n(-c2ccccc2)c1=O",
    "COc1ccc(Cc2cnc(N)nc2N)cc1OC",
    "COc1cc(O)cc(CCc2ccc(O)c(OC)c2)c1",
    "COc1c(C)cnc(CSc2nccn2C)c1C",
    "O=c1cc(-c2ccncc2)nc(-c2cccnc2)[nH]1",
    "O=c1c2ccccc2oc2nc3n(c(=O)c12)CCCS3",
    "O=c1c2cc(Cl)ccc2oc2nc3n(c(=O)c12)CCCS3",
]
GEN_SMILES = [
    "O=C(NC1CCN(C(=O)C2CC2)CC1)c1ccc(F)cc1",
    "NC(=O)c1cccc2[nH]c(-c3ccc(O)cc3)nc12",
    "CC(C)CCNC(=O)c1c[nH]c2ccccc2c1=O",
    "CCOc1ccc2[nH]cc(C(=O)NCc3cccnc3)c(=O)c2c1",
    "O=C(NCc1ccccc1)c1c[nH]c2c(F)cccc2c1=O",
    "CC(C)c1cccc(C(C)C)c1NCc1ccccn1",
    "CC1CCC(NC(=O)c2cc3ccccc3o2)CC1",
    "COc1ccc2[nH]cc(CCNC(=O)c3ccco3)c2c1",
]

logger.disable("polygraph")
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def _sample_generated_graphs(
    n: int, num_nodes: int = 64, start_seed: int = 0
) -> List[nx.Graph]:
    """Create a small set of Erdos-Renyi graphs as a stand-in for a generator."""
    return [
        nx.erdos_renyi_graph(num_nodes, 0.1, seed=i + start_seed)
        for i in range(n)
    ]


def data_location():
    cache_dir = user_cache_dir(f"polygraph-{polygraph.__version__}", "ANON_ORG")
    print(f"PolyGraph cache is typically located at: {cache_dir}")
    print(
        "It can be changed by setting the POLYGRAPH_CACHE_DIR environment variable."
    )
    print("Current value:", os.environ.get("POLYGRAPH_CACHE_DIR"))


def get_example_datasets():
    """Create a small set of Erdos-Renyi graphs as a stand-in for a generator, plus a reference dataset."""
    reference_ds = list(
        ProceduralPlanarGraphDataset("val", num_graphs=32).to_nx()
    )
    generated = _sample_generated_graphs(32)
    print(
        f"Reference graphs: {len(reference_ds)} | Generated graphs: {len(generated)}"
    )
    return reference_ds, generated


def calculate_gtv_mmd(reference, generated):
    """Calculate the Gaussian TV pseudokernel MMD between a reference dataset and a generated dataset."""
    print("GaussianTV MMD² Benchmark")
    gtv = GaussianTVMMD2Benchmark(reference)
    result = gtv.compute(generated)
    print("Computed Gaussian TV pseudokernel MMD²:")
    for metric, score in result.items():
        print(f"  {metric.capitalize()}: {score:.6f}")
    print()


def calculate_rbf_mmd(reference, generated):
    """Calculate the RBF MMD between a reference dataset and a generated dataset."""
    print("RBF MMD² Benchmark")
    rbf = RBFMMD2Benchmark(reference)
    result = rbf.compute(generated)
    print("Computed RBF MMD²:")
    for metric, score in result.items():
        print(f"  {metric.capitalize()}: {score:.6f}")
    print()


def calculate_pgs(reference, generated):
    """Calculate the standard PolyGraphScore between a reference dataset and a generated dataset."""
    print("PolyGraphScore (StandardPGS)")
    pgs = StandardPGS(reference)
    result = pgs.compute(generated)
    print(f"Overall PGS: {result['polygraphscore']:.6f}")
    print(f"Most powerful descriptor: {result['polygraphscore_descriptor']}")
    print("Subscores:")
    for metric, score in result["subscores"].items():
        print(f"  {metric.capitalize()}: {score:.6f}")
    print()


def calculate_molecule_pgs(ref_smiles, gen_smiles):
    """Calculate the PolyGraphScore between a reference dataset of molecules and a generated dataset of molecules."""
    from polygraph.metrics.molecule_pgs import MoleculePGS
    import rdkit.Chem

    ref_mols = [rdkit.Chem.MolFromSmiles(smiles) for smiles in ref_smiles]
    gen_mols = [rdkit.Chem.MolFromSmiles(smiles) for smiles in gen_smiles]

    print(
        f"PolyGraphScore (MoleculePGS) between {len(ref_mols)} reference and {len(gen_mols)} generated molecules:"
    )
    pgs = MoleculePGS(ref_mols)
    result = pgs.compute(gen_mols)
    print(f"Overall MoleculePGS: {result['polygraphscore']:.6f}")
    print(f"Most powerful descriptor: {result['polygraphscore_descriptor']}")
    print("Subscores:")
    for metric, score in result["subscores"].items():
        print(f"  {metric.replace('_', ' ').title()}: {score:.6f}")
    print()


def calculate_vun(reference, generated):
    """Calculate the VUN between a reference dataset and a generated dataset."""
    ds = ProceduralPlanarGraphDataset("val", num_graphs=1)
    validity_fn = ds.is_valid if reference is not None else None
    print("VUN")
    vun = VUN(reference, validity_fn=validity_fn)
    result = vun.compute(generated)
    print("Computed VUN:")
    for metric, score in result.items():
        print(f"  {metric.replace('_', ' ').title()}: {score:.6f}")
    print()


def main():
    print("=== PolyGraph Demo ===")

    # Data location-related information
    data_location()
    reference, generated = get_example_datasets()
    print()

    calculate_gtv_mmd(reference, generated)
    calculate_rbf_mmd(reference, generated)
    calculate_pgs(reference, generated)
    calculate_vun(reference, generated)
    calculate_molecule_pgs(REF_SMILES, GEN_SMILES)

    print("=== PolyGraph Demo End ===")


if __name__ == "__main__":
    main()
```