|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +"""polygraph_demo.py |
| 5 | +
|
| 6 | +In this file, we aim to demonstrate some of the features of the polygraph library. |
| 7 | +
|
| 8 | +""" |
| 9 | + |
| 10 | +import os |
| 11 | +from typing import List |
| 12 | +import warnings |
| 13 | + |
| 14 | +import networkx as nx |
| 15 | +from appdirs import user_cache_dir |
| 16 | +from loguru import logger |
| 17 | + |
| 18 | +import polygraph |
| 19 | +from polygraph.datasets import ProceduralPlanarGraphDataset |
| 20 | +from polygraph.metrics import ( |
| 21 | + VUN, |
| 22 | + GaussianTVMMD2Benchmark, |
| 23 | + RBFMMD2Benchmark, |
| 24 | + StandardPGS, |
| 25 | +) |
| 26 | + |
| 27 | + |
| 28 | +REF_SMILES = [ |
| 29 | + "Nc1ncnc2c1ncn2C1OC(CO)CC1F", |
| 30 | + "C=CCc1c(OC(C)=O)c2cccnc2n(-c2ccccc2)c1=O", |
| 31 | + "COc1ccc(Cc2cnc(N)nc2N)cc1OC", |
| 32 | + "COc1cc(O)cc(CCc2ccc(O)c(OC)c2)c1", |
| 33 | + "COc1c(C)cnc(CSc2nccn2C)c1C", |
| 34 | + "O=c1cc(-c2ccncc2)nc(-c2cccnc2)[nH]1", |
| 35 | + "O=c1c2ccccc2oc2nc3n(c(=O)c12)CCCS3", |
| 36 | + "O=c1c2cc(Cl)ccc2oc2nc3n(c(=O)c12)CCCS3", |
| 37 | +] |
| 38 | +GEN_SMILES = [ |
| 39 | + "O=C(NC1CCN(C(=O)C2CC2)CC1)c1ccc(F)cc1", |
| 40 | + "NC(=O)c1cccc2[nH]c(-c3ccc(O)cc3)nc12", |
| 41 | + "CC(C)CCNC(=O)c1c[nH]c2ccccc2c1=O", |
| 42 | + "CCOc1ccc2[nH]cc(C(=O)NCc3cccnc3)c(=O)c2c1", |
| 43 | + "O=C(NCc1ccccc1)c1c[nH]c2c(F)cccc2c1=O", |
| 44 | + "CC(C)c1cccc(C(C)C)c1NCc1ccccn1", |
| 45 | + "CC1CCC(NC(=O)c2cc3ccccc3o2)CC1", |
| 46 | + "COc1ccc2[nH]cc(CCNC(=O)c3ccco3)c2c1", |
| 47 | +] |
| 48 | + |
| 49 | +logger.disable("polygraph") |
| 50 | +warnings.filterwarnings("ignore", category=UserWarning) |
| 51 | +warnings.filterwarnings("ignore", category=FutureWarning) |
| 52 | + |
| 53 | + |
| 54 | +def _sample_generated_graphs( |
| 55 | + n: int, num_nodes: int = 64, start_seed: int = 0 |
| 56 | +) -> List[nx.Graph]: |
| 57 | + """Create a small set of Erdos-Renyi graphs as a stand-in for a generator.""" |
| 58 | + return [ |
| 59 | + nx.erdos_renyi_graph(num_nodes, 0.1, seed=i + start_seed) |
| 60 | + for i in range(n) |
| 61 | + ] |
| 62 | + |
| 63 | + |
| 64 | +def data_location(): |
| 65 | + cache_dir = user_cache_dir(f"polygraph-{polygraph.__version__}", "ANON_ORG") |
| 66 | + print(f"PolyGraph cache is typically located at: {cache_dir}") |
| 67 | + print( |
| 68 | + "It can be changed by setting the POLYGRAPH_CACHE_DIR environment variable." |
| 69 | + ) |
| 70 | + print("Current value: ", os.environ.get("POLYGRAPH_CACHE_DIR")) |
| 71 | + |
| 72 | + |
| 73 | +def get_example_datasets(): |
| 74 | + """ |
| 75 | + Create a small set of Erdos-Renyi graphs as a stand-in for a generator and a reference dataset. |
| 76 | + """ |
| 77 | + |
| 78 | + reference_ds = list( |
| 79 | + ProceduralPlanarGraphDataset("val", num_graphs=32).to_nx() |
| 80 | + ) |
| 81 | + generated = _sample_generated_graphs(32) |
| 82 | + print( |
| 83 | + f"Reference graphs: {len(reference_ds)} | Generated graphs: {len(generated)}" |
| 84 | + ) |
| 85 | + return reference_ds, generated |
| 86 | + |
| 87 | + |
| 88 | +def calculate_gtv_mmd(reference, generated): |
| 89 | + """ |
| 90 | + Calculate the GTV pseudokernel MMD between a reference dataset and a generated dataset. |
| 91 | + """ |
| 92 | + print("GaussianTV MMD² Benchmark") |
| 93 | + gtv = GaussianTVMMD2Benchmark(reference) |
| 94 | + result = gtv.compute(generated) |
| 95 | + print("Computed Gaussian TV pseudokernel MMD²:") |
| 96 | + for metric, score in result.items(): |
| 97 | + print(f" {metric.capitalize()}: {score:.6f}") |
| 98 | + print() |
| 99 | + |
| 100 | + |
| 101 | +def calculate_rbf_mmd(reference, generated): |
| 102 | + """ |
| 103 | + Calculate the RBF MMD between a reference dataset and a generated dataset. |
| 104 | + """ |
| 105 | + print("RBF MMD² Benchmark") |
| 106 | + rbf = RBFMMD2Benchmark(reference) |
| 107 | + result = rbf.compute(generated) |
| 108 | + print("Computed RBF MMD²:") |
| 109 | + for metric, score in result.items(): |
| 110 | + print(f" {metric.capitalize()}: {score:.6f}") |
| 111 | + print() |
| 112 | + |
| 113 | + |
| 114 | +def calculate_pgs(reference, generated): |
| 115 | + """ |
| 116 | + Calculate the standard PolyGraphScore between a reference dataset and a generated dataset. |
| 117 | + """ |
| 118 | + print("PolyGraphScore (StandardPGS)") |
| 119 | + pgs = StandardPGS(reference) |
| 120 | + result = pgs.compute(generated) |
| 121 | + print(f"Overall PGS: {result['polygraphscore']:.6f}") |
| 122 | + print(f"Most powerful descriptor: {result['polygraphscore_descriptor']}") |
| 123 | + print("Subscores:") |
| 124 | + for metric, score in result["subscores"].items(): |
| 125 | + print(f" {metric.capitalize()}: {score:.6f}") |
| 126 | + print() |
| 127 | + |
| 128 | + |
| 129 | +def calculate_molecule_pgs(ref_smiles, gen_smiles): |
| 130 | + """ |
| 131 | + Calculate the PolyGraphScore between a reference dataset of molecules and a generated dataset of molecules. |
| 132 | + """ |
| 133 | + from polygraph.metrics.molecule_pgs import MoleculePGS |
| 134 | + import rdkit.Chem |
| 135 | + |
| 136 | + ref_mols = [rdkit.Chem.MolFromSmiles(smiles) for smiles in ref_smiles] |
| 137 | + gen_mols = [rdkit.Chem.MolFromSmiles(smiles) for smiles in gen_smiles] |
| 138 | + |
| 139 | + print( |
| 140 | + f"PolyGraphScore (MoleculePGS) between {len(ref_mols)} reference and {len(gen_mols)} generated molecules:" |
| 141 | + ) |
| 142 | + pgs = MoleculePGS(ref_mols) |
| 143 | + result = pgs.compute(gen_mols) |
| 144 | + print(f"Overall MoleculePGS: {result['polygraphscore']:.6f}") |
| 145 | + print(f"Most powerful descriptor: {result['polygraphscore_descriptor']}") |
| 146 | + print("Subscores:") |
| 147 | + for metric, score in result["subscores"].items(): |
| 148 | + print(f" {metric.replace('_', ' ').title()}: {score:.6f}") |
| 149 | + print() |
| 150 | + |
| 151 | + |
| 152 | +def calculate_vun(reference, generated): |
| 153 | + """ |
| 154 | + Calculate the VUN between a reference dataset and a generated dataset. |
| 155 | + """ |
| 156 | + ds = ProceduralPlanarGraphDataset("val", num_graphs=1) |
| 157 | + validity_fn = ds.is_valid if reference is not None else None |
| 158 | + print("VUN") |
| 159 | + vun = VUN(reference, validity_fn=validity_fn) |
| 160 | + result = vun.compute(generated) |
| 161 | + print("Computed VUN:") |
| 162 | + for metric, score in result.items(): |
| 163 | + print(f" {metric.replace('_', ' ').title()}: {score:.6f}") |
| 164 | + print() |
| 165 | + |
| 166 | + |
| 167 | +def main(): |
| 168 | + print("=== PolyGraph Demo ===") |
| 169 | + |
| 170 | + # Data location-related information |
| 171 | + data_location() |
| 172 | + reference, generated = get_example_datasets() |
| 173 | + print() |
| 174 | + |
| 175 | + calculate_gtv_mmd(reference, generated) |
| 176 | + calculate_rbf_mmd(reference, generated) |
| 177 | + calculate_pgs(reference, generated) |
| 178 | + calculate_vun(reference, generated) |
| 179 | + calculate_molecule_pgs(REF_SMILES, GEN_SMILES) |
| 180 | + |
| 181 | + print("=== PolyGraph Demo End ===") |
| 182 | + |
| 183 | + |
| 184 | +if __name__ == "__main__": |
| 185 | + main() |
0 commit comments