Commit ab412ff

Added molecule PGS, customized PGS. Modified README slightly (#21)
1 parent b704e08 commit ab412ff

File tree

5 files changed, +282 -135 lines changed

README.md

Lines changed: 15 additions & 25 deletions
````diff
@@ -11,18 +11,18 @@ conda create -n polygraph-benchmark python=3.10
 conda activate polygraph-benchmark
 ```
 
-Then install
+Then install
 
 ```bash
 pip install -e .
 ```
 
-If you'd like to use SBM graph dataset validation with `graph_tool`, use a mamba or pixi environment. More information is available in the documentation.
+If you'd like to use SBM graph dataset validation with `graph_tool`, use a mamba or pixi environment. More information is available in the documentation.
 
 
+PolyGraph is a Python library for evaluating graph generative models by providing standardized datasets and metrics
+(including PolyGraphScore).
 
 
-PolyGraph is a Python library for evaluating graph generative models by providing standardized datasets and metrics
-(including PolyGraphScore).
 
 ## At a glance
 
````
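The `graph_tool` note above is actionable with a conda-forge environment; below is a sketch of one way to set that up. The channel and package name come from conda-forge's packaging of `graph-tool`; the environment name simply mirrors the README's and is not prescribed by this commit.

```shell
# Sketch: graph-tool is distributed via conda-forge, not PyPI,
# which is why the README suggests a mamba (or pixi) environment.
mamba create -n polygraph-benchmark -c conda-forge python=3.10 graph-tool
mamba activate polygraph-benchmark
pip install -e .
```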
````diff
@@ -37,33 +37,21 @@ Here are a set of datasets and metrics this library provides:
 - PolyGraphScore: `StandardPGS`, `MolecularPGS` (for molecule descriptors).
 - Validation/Uniqueness/Novelty: `VUN`.
 - Uncertainty quantification for benchmarking (`GaussianTVMMD2BenchmarkInterval`, `RBFMMD2Benchmark`, `PGS5Interval`)
-- 🧩 **Interoperability**: works with PyTorch Geometric and NetworkX; caching via `POLYGRAPH_CACHE_DIR`. Works on Apple Silicon Macs and Linux.
+- 🧩 **Extendable**: Users can instantiate custom metrics by specifying descriptors, kernels, or classifiers (`PolyGraphScore`, `DescriptorMMD2`). PolyGraph defines all necessary interfaces but imposes no requirements on the data type of graph objects.
+- ⚙️ **Interoperability**: Works on Apple Silicon Macs and Linux.
 - **Tested, type checked and documented**
 
 
-<details>
-<summary><strong>⚠️ Important - Dataset Usage Warning</strong></summary>
-
-**To help reproduce previous results, we provide the following datasets:**
-- `PlanarGraphDataset`
-- `SBMGraphDataset`
-- `LobsterGraphDataset`
-
-But they should not be used for benchmarking, due to unreliable metric estimates (see our paper for more details).
-
-We provide larger datasets that should be used instead:
-- `PlanarLGraphDataset`
-- `SBMLGraphDataset`
-- `LobsterLGraphDataset`
-
-</details>
 
 ## Tutorial
 
-Our [demo script](polygraph_demo.py) showcases some features of our library in action.
+Our [demo script](demo_polygraph.py) showcases some basic features of our library in action.
+For more advanced usage (namely, defining custom metrics), we refer to our [second demo script](demo_custom_metrics.py).
+
 
 ### Datasets
 Instantiate a benchmark dataset as follows:
+
 ```python
 import networkx as nx
 from polygraph.datasets import PlanarGraphDataset
````
````diff
@@ -74,10 +62,12 @@ reference = PlanarGraphDataset("test").to_nx()
 generated = [nx.erdos_renyi_graph(64, 0.1) for _ in range(40)]
 ```
 
+
 ### Metrics
 
 #### Maximum Mean Discrepancy
 To compute existing MMD2 formulations (e.g. based on the TV pseudokernel), one can use the following:
+
 ```python
 from polygraph.metrics import GaussianTVMMD2Benchmark # Can also be RBFMMD2Benchmark
 
````
````diff
@@ -90,7 +80,7 @@ print(gtv_benchmark.compute(generated)) # {'orbit': ..., 'clustering': ..., 'de
 Similarly, you can compute our proposed PolyGraphScore, like so:
 
 ```python
-from polygraph.metrics import StandardPGS
+from polygraph.metrics import StandardPGS
 
 pgs = StandardPGS(reference)
 print(pgs.compute(generated)) # {'polygraphscore': ..., 'polygraphscore_descriptor': ..., 'subscores': {'orbit': ..., }}
````
````diff
@@ -100,6 +90,7 @@ print(pgs.compute(generated)) # {'polygraphscore': ..., 'polygraphscore_descript
 
 #### Validity, uniqueness and novelty
 VUN values follow a similar interface:
+
 ```python
 from polygraph.metrics import VUN
 reference_ds = PlanarGraphDataset("test")
````
````diff
@@ -111,7 +102,7 @@ print(pgs.compute(generated)) # {'valid': ..., 'valid_unique_novel': ..., 'vali
 
 For MMD and PGS, uncertainty quantification for the metrics is obtained through subsampling. For VUN, a confidence interval is obtained with a binomial test.
 
-For `VUN`, the results can be obtained by specifying a confidence level when instantiating the metric.
+For `VUN`, the results can be obtained by specifying a confidence level when instantiating the metric.
 
 For the others, the `Interval` suffix references the class that implements subsampling.
 
````
````diff
@@ -130,4 +121,3 @@ for metric in tqdm(metrics):
         generated,
     )
 ```
-
````
demo_custom_metrics.py

Lines changed: 82 additions & 0 deletions
```python
from typing import Iterable, Collection

import numpy as np
import networkx as nx
from loguru import logger
from polygraph.datasets import PlanarGraphDataset, SBMGraphDataset
from polygraph.metrics.base import PolyGraphScore, DescriptorMMD2
from polygraph.utils.kernels import LinearKernel
from polygraph.utils.descriptors import ClusteringHistogram
from sklearn.linear_model import LogisticRegression

logger.disable("polygraph")


def betweenness_descriptor(graphs: Iterable[nx.Graph]) -> np.ndarray:
    """A custom graph descriptor that computes betweenness centrality.

    This implements the polygraph.utils.descriptors.GraphDescriptor interface.
    """
    histograms = []
    for graph in graphs:
        btw_values = list(nx.betweenness_centrality(graph).values())
        histograms.append(
            np.histogram(btw_values, bins=100, range=(0.0, 1.0), density=True)[0]
        )
    return np.stack(histograms, axis=0)


def calculate_custom_mmd(
    reference: Collection[nx.Graph], generated: Collection[nx.Graph]
):
    """
    Calculate a customized MMD between a reference dataset and a generated dataset.

    This MMD uses a linear kernel based on betweenness centrality histograms.
    It is estimated using the unbiased minimum variance estimator.
    """
    print("Calculating custom MMD...")
    mmd = DescriptorMMD2(
        reference, kernel=LinearKernel(betweenness_descriptor), variant="umve"
    )
    print(f"Custom MMD: {mmd.compute(generated)} \n")


def calculate_custom_pgs(
    reference: Collection[nx.Graph], generated: Collection[nx.Graph]
):
    """
    Calculate a customized PGS between a reference dataset and a generated dataset.

    This PGS uses betweenness centrality and clustering coefficients as graph
    descriptors. Instead of TabPFN, it uses logistic regression.

    PolyGraphScore may be instantiated with any descriptors implementing the
    `polygraph.utils.descriptors.GraphDescriptor` interface and any classifier
    implementing the `polygraph.metrics.base.polygraphscore.ClassifierProtocol`
    interface (i.e., the sklearn interface).
    """
    print("Calculating custom PGS...")
    pgs = PolyGraphScore(
        reference,
        descriptors={
            "betweenness": betweenness_descriptor,
            "clustering": ClusteringHistogram(bins=100),
        },
        classifier=LogisticRegression(),
    )
    result = pgs.compute(generated)

    print(f"Overall PGS: {result['polygraphscore']:.6f}")
    print(f"Score Descriptor: {result['polygraphscore_descriptor']}")
    print("\nSubscores:")
    for metric, score in result["subscores"].items():
        print(f" {metric.capitalize()}: {score:.6f}")
    print()


if __name__ == "__main__":
    reference = list(PlanarGraphDataset("val").to_nx())
    generated = list(SBMGraphDataset("val").to_nx())

    calculate_custom_mmd(reference, generated)
    print("=" * 50, "\n")
    calculate_custom_pgs(reference, generated)
```
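The descriptor contract used by `betweenness_descriptor` (an iterable of graphs in, one feature row per graph out) can be sanity-checked without polygraph installed. The snippet below repeats the same histogram logic on two toy NetworkX graphs; the function name here is illustrative and not part of the library.

```python
import networkx as nx
import numpy as np


def betweenness_histograms(graphs):
    """Same logic as betweenness_descriptor: one 100-bin density histogram per graph."""
    histograms = []
    for graph in graphs:
        btw_values = list(nx.betweenness_centrality(graph).values())
        histograms.append(
            np.histogram(btw_values, bins=100, range=(0.0, 1.0), density=True)[0]
        )
    return np.stack(histograms, axis=0)


# Two toy graphs stand in for a dataset; normalized betweenness lies in [0, 1],
# so every value lands in a bin and each density histogram integrates to 1.
features = betweenness_histograms([nx.cycle_graph(8), nx.star_graph(7)])
print(features.shape)  # one row per graph, one column per bin: (2, 100)
```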

demo_polygraph.py

Lines changed: 185 additions & 0 deletions
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""demo_polygraph.py

In this file, we aim to demonstrate some of the features of the polygraph library.
"""

import os
from typing import List
import warnings

import networkx as nx
from appdirs import user_cache_dir
from loguru import logger

import polygraph
from polygraph.datasets import ProceduralPlanarGraphDataset
from polygraph.metrics import (
    VUN,
    GaussianTVMMD2Benchmark,
    RBFMMD2Benchmark,
    StandardPGS,
)


REF_SMILES = [
    "Nc1ncnc2c1ncn2C1OC(CO)CC1F",
    "C=CCc1c(OC(C)=O)c2cccnc2n(-c2ccccc2)c1=O",
    "COc1ccc(Cc2cnc(N)nc2N)cc1OC",
    "COc1cc(O)cc(CCc2ccc(O)c(OC)c2)c1",
    "COc1c(C)cnc(CSc2nccn2C)c1C",
    "O=c1cc(-c2ccncc2)nc(-c2cccnc2)[nH]1",
    "O=c1c2ccccc2oc2nc3n(c(=O)c12)CCCS3",
    "O=c1c2cc(Cl)ccc2oc2nc3n(c(=O)c12)CCCS3",
]
GEN_SMILES = [
    "O=C(NC1CCN(C(=O)C2CC2)CC1)c1ccc(F)cc1",
    "NC(=O)c1cccc2[nH]c(-c3ccc(O)cc3)nc12",
    "CC(C)CCNC(=O)c1c[nH]c2ccccc2c1=O",
    "CCOc1ccc2[nH]cc(C(=O)NCc3cccnc3)c(=O)c2c1",
    "O=C(NCc1ccccc1)c1c[nH]c2c(F)cccc2c1=O",
    "CC(C)c1cccc(C(C)C)c1NCc1ccccn1",
    "CC1CCC(NC(=O)c2cc3ccccc3o2)CC1",
    "COc1ccc2[nH]cc(CCNC(=O)c3ccco3)c2c1",
]

logger.disable("polygraph")
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def _sample_generated_graphs(
    n: int, num_nodes: int = 64, start_seed: int = 0
) -> List[nx.Graph]:
    """Create a small set of Erdos-Renyi graphs as a stand-in for a generator."""
    return [
        nx.erdos_renyi_graph(num_nodes, 0.1, seed=i + start_seed)
        for i in range(n)
    ]


def data_location():
    cache_dir = user_cache_dir(f"polygraph-{polygraph.__version__}", "ANON_ORG")
    print(f"PolyGraph cache is typically located at: {cache_dir}")
    print(
        "It can be changed by setting the POLYGRAPH_CACHE_DIR environment variable."
    )
    print("Current value: ", os.environ.get("POLYGRAPH_CACHE_DIR"))


def get_example_datasets():
    """
    Create a small set of Erdos-Renyi graphs as a stand-in for a generator and a reference dataset.
    """
    reference_ds = list(
        ProceduralPlanarGraphDataset("val", num_graphs=32).to_nx()
    )
    generated = _sample_generated_graphs(32)
    print(
        f"Reference graphs: {len(reference_ds)} | Generated graphs: {len(generated)}"
    )
    return reference_ds, generated


def calculate_gtv_mmd(reference, generated):
    """
    Calculate the GTV pseudokernel MMD between a reference dataset and a generated dataset.
    """
    print("GaussianTV MMD² Benchmark")
    gtv = GaussianTVMMD2Benchmark(reference)
    result = gtv.compute(generated)
    print("Computed Gaussian TV pseudokernel MMD²:")
    for metric, score in result.items():
        print(f" {metric.capitalize()}: {score:.6f}")
    print()


def calculate_rbf_mmd(reference, generated):
    """
    Calculate the RBF MMD between a reference dataset and a generated dataset.
    """
    print("RBF MMD² Benchmark")
    rbf = RBFMMD2Benchmark(reference)
    result = rbf.compute(generated)
    print("Computed RBF MMD²:")
    for metric, score in result.items():
        print(f" {metric.capitalize()}: {score:.6f}")
    print()


def calculate_pgs(reference, generated):
    """
    Calculate the standard PolyGraphScore between a reference dataset and a generated dataset.
    """
    print("PolyGraphScore (StandardPGS)")
    pgs = StandardPGS(reference)
    result = pgs.compute(generated)
    print(f"Overall PGS: {result['polygraphscore']:.6f}")
    print(f"Most powerful descriptor: {result['polygraphscore_descriptor']}")
    print("Subscores:")
    for metric, score in result["subscores"].items():
        print(f" {metric.capitalize()}: {score:.6f}")
    print()


def calculate_molecule_pgs(ref_smiles, gen_smiles):
    """
    Calculate the PolyGraphScore between a reference dataset of molecules and a generated dataset of molecules.
    """
    from polygraph.metrics.molecule_pgs import MoleculePGS
    import rdkit.Chem

    ref_mols = [rdkit.Chem.MolFromSmiles(smiles) for smiles in ref_smiles]
    gen_mols = [rdkit.Chem.MolFromSmiles(smiles) for smiles in gen_smiles]

    print(
        f"PolyGraphScore (MoleculePGS) between {len(ref_mols)} reference and {len(gen_mols)} generated molecules:"
    )
    pgs = MoleculePGS(ref_mols)
    result = pgs.compute(gen_mols)
    print(f"Overall MoleculePGS: {result['polygraphscore']:.6f}")
    print(f"Most powerful descriptor: {result['polygraphscore_descriptor']}")
    print("Subscores:")
    for metric, score in result["subscores"].items():
        print(f" {metric.replace('_', ' ').title()}: {score:.6f}")
    print()


def calculate_vun(reference, generated):
    """
    Calculate the VUN between a reference dataset and a generated dataset.
    """
    ds = ProceduralPlanarGraphDataset("val", num_graphs=1)
    validity_fn = ds.is_valid if reference is not None else None
    print("VUN")
    vun = VUN(reference, validity_fn=validity_fn)
    result = vun.compute(generated)
    print("Computed VUN:")
    for metric, score in result.items():
        print(f" {metric.replace('_', ' ').title()}: {score:.6f}")
    print()


def main():
    print("=== PolyGraph Demo ===")

    # Data location-related information
    data_location()
    reference, generated = get_example_datasets()
    print()

    calculate_gtv_mmd(reference, generated)
    calculate_rbf_mmd(reference, generated)
    calculate_pgs(reference, generated)
    calculate_vun(reference, generated)
    calculate_molecule_pgs(REF_SMILES, GEN_SMILES)

    print("=== PolyGraph Demo End ===")


if __name__ == "__main__":
    main()
```
