diff --git a/README.md b/README.md index 3c8c7e22..188bab41 100644 --- a/README.md +++ b/README.md @@ -11,18 +11,18 @@ conda create -n polygraph-benchmark python=3.10 conda activate polygraph-benchmark ``` -Then install +Then install ```bash pip install -e . ``` -If you'd like to use SBM graph dataset validation with `graph_tool`, use a mamba or pixi environment. More information is available in the documentation. +If you'd like to use SBM graph dataset validation with `graph_tool`, use a mamba or pixi environment. More information is available in the documentation. +PolyGraph is a Python library for evaluating graph generative models by providing standardized datasets and metrics +(including PolyGraphScore). -PolyGraph is a Python library for evaluating graph generative models by providing standardized datasets and metrics -(including PolyGraphScore). ## At a glance @@ -37,33 +37,21 @@ Here are a set of datasets and metrics this library provides: - PolyGraphScore: `StandardPGS`, `MolecularPGS` (for molecule descriptors). - Validation/Uniqueness/Novelty: `VUN`. - Uncertainty quantification for benchmarking (`GaussianTVMMD2BenchmarkInterval`, `RBFMMD2Benchmark`, `PGS5Interval`) -- 🧩 **Interoperability**: works with PyTorch Geometric and NetworkX; caching via `POLYGRAPH_CACHE_DIR`. Works on Apple Silicon Macs and Linux. +- 🧩 **Extendable**: Users can instantiate custom metrics by specifying descriptors, kernels, or classifiers (`PolyGraphScore`, `DescriptorMMD2`). PolyGraph defines all necessary interfaces but imposes no requirements on the data type of graph objects. +- ⚙️ **Interoperability**: Works on Apple Silicon Macs and Linux. - ✅ **Tested, type checked and documented** -
-⚠️ Important - Dataset Usage Warning - -**To help reproduce previous results, we provide the following datasets:** -- `PlanarGraphDataset` -- `SBMGraphDataset` -- `LobsterGraphDataset` - -But they should not be used for benchmarking, due to unreliable metric estimates (see our paper for more details). - -We provide larger datasets that should be used instead: -- `PlanarLGraphDataset` -- `SBMLGraphDataset` -- `LobsterLGraphDataset` - -
## Tutorial -Our [demo script](polygraph_demo.py) showcases some features of our library in action. +Our [demo script](demo_polygraph.py) showcases some basic features of our library in action. +For more advanced usage (namely, defining custom metrics), we refer to our [second demo script](demo_custom_metrics.py). + ### Datasets Instantiate a benchmark dataset as follows: + ```python import networkx as nx from polygraph.datasets import PlanarGraphDataset @@ -74,10 +62,12 @@ reference = PlanarGraphDataset("test").to_nx() generated = [nx.erdos_renyi_graph(64, 0.1) for _ in range(40)] ``` + ### Metrics #### Maximum Mean Discrepancy To compute existing MMD2 formulations (e.g. based on the TV pseudokernel), one can use the following: + ```python from polygraph.metrics import GaussianTVMMD2Benchmark # Can also be RBFMMD2Benchmark @@ -90,7 +80,7 @@ print(gtv_benchmark.compute(generated)) # {'orbit': ..., 'clustering': ..., 'de Similarly, you can compute our proposed PolyGraphScore, like so: ```python -from polygraph.metrics import StandardPGS +from polygraph.metrics import StandardPGS pgs = StandardPGS(reference) print(pgs.compute(generated)) # {'polygraphscore': ..., 'polygraphscore_descriptor': ..., 'subscores': {'orbit': ..., }} @@ -100,6 +90,7 @@ print(pgs.compute(generated)) # {'polygraphscore': ..., 'polygraphscore_descript #### Validity, uniqueness and novelty VUN values follow a similar interface: + ```python from polygraph.metrics import VUN reference_ds = PlanarGraphDataset("test") @@ -111,7 +102,7 @@ print(pgs.compute(generated)) # {'valid': ..., 'valid_unique_novel': ..., 'vali For MMD and PGS, uncertainty quantification for the metrics is obtained through subsampling. For VUN, a confidence interval is obtained with a binomial test. -For `VUN`, the results can be obtained by specifying a confidence level when instantiating the metric. +For `VUN`, the results can be obtained by specifying a confidence level when instantiating the metric. 
For the others, the `Interval` suffix references the class that implements subsampling. @@ -130,4 +121,3 @@ for metric in tqdm(metrics): generated, ) ``` - diff --git a/demo_custom_metrics.py b/demo_custom_metrics.py new file mode 100644 index 00000000..f97fff0a --- /dev/null +++ b/demo_custom_metrics.py @@ -0,0 +1,82 @@ +from typing import Iterable, Collection +import numpy as np +import networkx as nx +from loguru import logger +from polygraph.datasets import PlanarGraphDataset, SBMGraphDataset +from polygraph.metrics.base import PolyGraphScore, DescriptorMMD2 +from polygraph.utils.kernels import LinearKernel +from polygraph.utils.descriptors import ClusteringHistogram +from sklearn.linear_model import LogisticRegression + +logger.disable("polygraph") + + +def betweenness_descriptor(graphs: Iterable[nx.Graph]) -> np.ndarray: + """A custom graph descriptor that computes betweenness centrality. + + This implements the polygraph.utils.descriptors.GraphDescriptor interface. + """ + histograms = [] + for graph in graphs: + btw_values = list(nx.betweenness_centrality(graph).values()) + histograms.append( + np.histogram(btw_values, bins=100, range=(0.0, 1.0), density=True)[ + 0 + ] + ) + return np.stack(histograms, axis=0) + + +def calculate_custom_mmd( + reference: Collection[nx.Graph], generated: Collection[nx.Graph] +): + """ + Calculate a customized MMD between a reference dataset and a generated dataset. + + This MMD uses a linear kernel based on betweenness centrality histograms. + It is estimated using the unbiased minimum variance estimator. + """ + print("Calculating custom MMD...") + mmd = DescriptorMMD2( + reference, kernel=LinearKernel(betweenness_descriptor), variant="umve" + ) + print(f"Custom MMD: {mmd.compute(generated)} \n") + + +def calculate_custom_pgs( + reference: Collection[nx.Graph], generated: Collection[nx.Graph] +): + """ + Calculate a customized PGS between a reference dataset and a generated dataset. 
+ + This PGS uses betweenness centrality and clustering coefficients as graph descriptors. Instead of TabPFN, it uses logistic regression. + + PolyGraphScore may be instantiated with any descriptors implementing the `polygraph.utils.descriptors.GraphDescriptor` interface + and any classifier implementing the `polygraph.metrics.base.polygraphscore.ClassifierProtocol` interface (i.e., the sklearn interface). + """ + print("Calculating custom PGS...") + pgs = PolyGraphScore( + reference, + descriptors={ + "betweenness": betweenness_descriptor, + "clustering": ClusteringHistogram(bins=100), + }, + classifier=LogisticRegression(), + ) + result = pgs.compute(generated) + + print(f"Overall PGS: {result['polygraphscore']:.6f}") + print(f"Score Descriptor: {result['polygraphscore_descriptor']}") + print("\nSubscores:") + for metric, score in result["subscores"].items(): + print(f" {metric.capitalize()}: {score:.6f}") + print() + + +if __name__ == "__main__": + reference = list(PlanarGraphDataset("val").to_nx()) + generated = list(SBMGraphDataset("val").to_nx()) + + calculate_custom_mmd(reference, generated) + print("=" * 50, "\n") + calculate_custom_pgs(reference, generated) diff --git a/demo_polygraph.py b/demo_polygraph.py new file mode 100644 index 00000000..f0ad8bdc --- /dev/null +++ b/demo_polygraph.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""polygraph_demo.py + +In this file, we aim to demonstrate some of the features of the polygraph library. 
+ +""" + +import os +from typing import List +import warnings + +import networkx as nx +from appdirs import user_cache_dir +from loguru import logger + +import polygraph +from polygraph.datasets import ProceduralPlanarGraphDataset +from polygraph.metrics import ( + VUN, + GaussianTVMMD2Benchmark, + RBFMMD2Benchmark, + StandardPGS, +) + + +REF_SMILES = [ + "Nc1ncnc2c1ncn2C1OC(CO)CC1F", + "C=CCc1c(OC(C)=O)c2cccnc2n(-c2ccccc2)c1=O", + "COc1ccc(Cc2cnc(N)nc2N)cc1OC", + "COc1cc(O)cc(CCc2ccc(O)c(OC)c2)c1", + "COc1c(C)cnc(CSc2nccn2C)c1C", + "O=c1cc(-c2ccncc2)nc(-c2cccnc2)[nH]1", + "O=c1c2ccccc2oc2nc3n(c(=O)c12)CCCS3", + "O=c1c2cc(Cl)ccc2oc2nc3n(c(=O)c12)CCCS3", +] +GEN_SMILES = [ + "O=C(NC1CCN(C(=O)C2CC2)CC1)c1ccc(F)cc1", + "NC(=O)c1cccc2[nH]c(-c3ccc(O)cc3)nc12", + "CC(C)CCNC(=O)c1c[nH]c2ccccc2c1=O", + "CCOc1ccc2[nH]cc(C(=O)NCc3cccnc3)c(=O)c2c1", + "O=C(NCc1ccccc1)c1c[nH]c2c(F)cccc2c1=O", + "CC(C)c1cccc(C(C)C)c1NCc1ccccn1", + "CC1CCC(NC(=O)c2cc3ccccc3o2)CC1", + "COc1ccc2[nH]cc(CCNC(=O)c3ccco3)c2c1", +] + +logger.disable("polygraph") +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +def _sample_generated_graphs( + n: int, num_nodes: int = 64, start_seed: int = 0 +) -> List[nx.Graph]: + """Create a small set of Erdos-Renyi graphs as a stand-in for a generator.""" + return [ + nx.erdos_renyi_graph(num_nodes, 0.1, seed=i + start_seed) + for i in range(n) + ] + + +def data_location(): + cache_dir = user_cache_dir(f"polygraph-{polygraph.__version__}", "ANON_ORG") + print(f"PolyGraph cache is typically located at: {cache_dir}") + print( + "It can be changed by setting the POLYGRAPH_CACHE_DIR environment variable." + ) + print("Current value: ", os.environ.get("POLYGRAPH_CACHE_DIR")) + + +def get_example_datasets(): + """ + Create a small set of Erdos-Renyi graphs as a stand-in for a generator and a reference dataset. 
+ """ + + reference_ds = list( + ProceduralPlanarGraphDataset("val", num_graphs=32).to_nx() + ) + generated = _sample_generated_graphs(32) + print( + f"Reference graphs: {len(reference_ds)} | Generated graphs: {len(generated)}" + ) + return reference_ds, generated + + +def calculate_gtv_mmd(reference, generated): + """ + Calculate the GTV pseudokernel MMD between a reference dataset and a generated dataset. + """ + print("GaussianTV MMD² Benchmark") + gtv = GaussianTVMMD2Benchmark(reference) + result = gtv.compute(generated) + print("Computed Gaussian TV pseudokernel MMD²:") + for metric, score in result.items(): + print(f" {metric.capitalize()}: {score:.6f}") + print() + + +def calculate_rbf_mmd(reference, generated): + """ + Calculate the RBF MMD between a reference dataset and a generated dataset. + """ + print("RBF MMD² Benchmark") + rbf = RBFMMD2Benchmark(reference) + result = rbf.compute(generated) + print("Computed RBF MMD²:") + for metric, score in result.items(): + print(f" {metric.capitalize()}: {score:.6f}") + print() + + +def calculate_pgs(reference, generated): + """ + Calculate the standard PolyGraphScore between a reference dataset and a generated dataset. + """ + print("PolyGraphScore (StandardPGS)") + pgs = StandardPGS(reference) + result = pgs.compute(generated) + print(f"Overall PGS: {result['polygraphscore']:.6f}") + print(f"Most powerful descriptor: {result['polygraphscore_descriptor']}") + print("Subscores:") + for metric, score in result["subscores"].items(): + print(f" {metric.capitalize()}: {score:.6f}") + print() + + +def calculate_molecule_pgs(ref_smiles, gen_smiles): + """ + Calculate the PolyGraphScore between a reference dataset of molecules and a generated dataset of molecules. 
+ """ + from polygraph.metrics.molecule_pgs import MoleculePGS + import rdkit.Chem + + ref_mols = [rdkit.Chem.MolFromSmiles(smiles) for smiles in ref_smiles] + gen_mols = [rdkit.Chem.MolFromSmiles(smiles) for smiles in gen_smiles] + + print( + f"PolyGraphScore (MoleculePGS) between {len(ref_mols)} reference and {len(gen_mols)} generated molecules:" + ) + pgs = MoleculePGS(ref_mols) + result = pgs.compute(gen_mols) + print(f"Overall MoleculePGS: {result['polygraphscore']:.6f}") + print(f"Most powerful descriptor: {result['polygraphscore_descriptor']}") + print("Subscores:") + for metric, score in result["subscores"].items(): + print(f" {metric.replace('_', ' ').title()}: {score:.6f}") + print() + + +def calculate_vun(reference, generated): + """ + Calculate the VUN between a reference dataset and a generated dataset. + """ + ds = ProceduralPlanarGraphDataset("val", num_graphs=1) + validity_fn = ds.is_valid if reference is not None else None + print("VUN") + vun = VUN(reference, validity_fn=validity_fn) + result = vun.compute(generated) + print("Computed VUN:") + for metric, score in result.items(): + print(f" {metric.replace('_', ' ').title()}: {score:.6f}") + print() + + +def main(): + print("=== PolyGraph Demo ===") + + # Data location-related information + data_location() + reference, generated = get_example_datasets() + print() + + calculate_gtv_mmd(reference, generated) + calculate_rbf_mmd(reference, generated) + calculate_pgs(reference, generated) + calculate_vun(reference, generated) + calculate_molecule_pgs(REF_SMILES, GEN_SMILES) + + print("=== PolyGraph Demo End ===") + + +if __name__ == "__main__": + main() diff --git a/polygraph_demo.py b/polygraph_demo.py deleted file mode 100644 index 5705a164..00000000 --- a/polygraph_demo.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -"""polygraph_demo.py - -In this file, we aim to demonstrate some of the features of the polygraph library. 
- -""" - -import os -from typing import List - -import networkx as nx -from appdirs import user_cache_dir -from loguru import logger - -import polygraph -from polygraph.datasets import ProceduralPlanarGraphDataset -from polygraph.metrics import ( - VUN, - GaussianTVMMD2Benchmark, - RBFMMD2Benchmark, - StandardPGS, -) - - -def _sample_generated_graphs(n: int, num_nodes: int = 64, start_seed: int = 0) -> List[nx.Graph]: - """Create a small set of Erdos-Renyi graphs as a stand-in for a generator.""" - return [nx.erdos_renyi_graph(num_nodes, 0.1, seed=i + start_seed) for i in range(n)] - -def data_location(): - cache_dir = user_cache_dir(f"polygraph-{polygraph.__version__}", "ANON_ORG") - logger.info(f"PolyGraph cache is typically located at: {cache_dir}") - logger.info("It can be changed by setting the POLYGRAPH_CACHE_DIR environment variable.") - logger.info("Current value: ", os.environ.get("POLYGRAPH_CACHE_DIR")) - -def get_example_datasets(): - """ - Create a small set of Erdos-Renyi graphs as a stand-in for a generator and a reference dataset. - """ - - reference_ds = ProceduralPlanarGraphDataset("val", num_graphs=32).to_nx() - generated = _sample_generated_graphs(32) - logger.info(f"Reference graphs: {len(reference_ds)} | Generated graphs: {len(generated)}") - return reference_ds, generated - -def calculate_gtv_mmd(reference, generated): - """ - Calculate the GTV pseudokernel MMD between a reference dataset and a generated dataset. - """ - logger.info("GaussianTV MMD² Benchmark") - gtv = GaussianTVMMD2Benchmark(reference) - logger.info(f"Computed Gaussian TV pseudokernel MMD²: {gtv.compute(generated)}") - -def calculate_rbf_mmd(reference, generated): - """ - Calculate the RBF MMD between a reference dataset and a generated dataset. 
- """ - logger.info("RBF MMD² Benchmark") - rbf = RBFMMD2Benchmark(reference) - logger.info(f"Computed RBF MMD²: {rbf.compute(generated)}") - - -def calculate_pgs(reference, generated): - """ - Calculate the PolyGraphScore between a reference dataset and a generated dataset. - """ - logger.info("PolyGraphScore (StandardPGS)") - pgs = StandardPGS(reference) - logger.info(f"Computed PolyGraphScore: {pgs.compute(generated)}") - -def calculate_vun(reference, generated): - """ - Calculate the VUN between a reference dataset and a generated dataset. - """ - ds = ProceduralPlanarGraphDataset("val", num_graphs=1) - validity_fn = ds.is_valid if reference is not None else None - logger.info("VUN") - vun = VUN(reference, validity_fn=validity_fn) - logger.info(f"Computed VUN: {vun.compute(generated)}") - - -def main(): - logger.info("=== PolyGraph Demo ===") - - # Data location-related information - data_location() - reference, generated = get_example_datasets() - - calculate_gtv_mmd(reference, generated) - calculate_rbf_mmd(reference, generated) - calculate_pgs(reference, generated) - calculate_vun(reference, generated) - - logger.success("=== PolyGraph Demo End ===") - -if __name__ == "__main__": - main() diff --git a/tests/test_online_datasets.py b/tests/test_online_datasets.py index 1eeac786..77685dc1 100644 --- a/tests/test_online_datasets.py +++ b/tests/test_online_datasets.py @@ -9,7 +9,6 @@ from tqdm.rich import tqdm from polygraph.datasets import ( - URLGraphDataset, MOSES, QM9, DobsonDoigGraphDataset, @@ -205,14 +204,3 @@ def test_split_disjointness(ds_cls): assert result["unique"] == 1 assert result["novel"] == 1 prev_splits.extend(graphs) - - -@pytest.mark.parametrize("memmap", [True, False]) -def test_url_dataset(memmap): - ds = URLGraphDataset( - "https://datashare.biochem.mpg.de/s/f3kXPP4LICWKbBx/download", - memmap=memmap, - ) - assert len(ds) == 128 - assert isinstance(ds[0], Data) - assert ds.to_nx()[0].number_of_nodes() == 64