Skip to content

Commit ddb9559

Browse files
authored
Merge pull request #97 from diningphil/store-metrics-over-time
added option to store metrics over time in a file
2 parents b8aedef + d0bc0c0 commit ddb9559

5 files changed

Lines changed: 47 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
## [1.5.6] Storing metrics on disk
4+
5+
### Added
6+
7+
- You can now store metrics trend across epochs using `Plotter`. Just pass the argument `store_on_disk=True` in the configuration file of the experiment.
8+
39
## [1.5.5]
410

511
### Fixed

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
author = "Federico Errica"
2525

2626
# The full version, including alpha/beta/rc tags
27-
release = "1.5.5"
27+
release = "1.5.6"
2828

2929

3030
# -- General configuration ---------------------------------------------------

examples/MODEL_CONFIGS/config_SupToyDGN.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,7 @@ grid:
119119
checkpoint: True # store the best checkpoint
120120

121121
# Plotter of metrics
122-
plotter: pydgn.training.callback.plotter.Plotter
122+
plotter:
123+
- class_name: pydgn.training.callback.plotter.Plotter
124+
args:
125+
store_on_disk: True

pydgn/training/callback/plotter.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from pathlib import Path
33

4+
import torch
45
from torch.utils.tensorboard import SummaryWriter
56

67
from pydgn.static import *
@@ -14,17 +15,29 @@ class Plotter(EventHandler):
1415
1516
Args:
1617
exp_path (str): path where to store the Tensorboard logs
18+
store_on_disk (bool): whether to store all metrics on disk.
19+
Defaults to False
1720
kwargs (dict): additional arguments that may depend on the plotter
1821
"""
1922

20-
def __init__(self, exp_path: str, **kwargs: dict):
23+
def __init__(self,
24+
exp_path: str,
25+
store_on_disk: bool = False,
26+
**kwargs: dict):
2127
super().__init__()
2228
self.exp_path = exp_path
29+
self.store_on_disk = store_on_disk
2330

2431
if not os.path.exists(Path(self.exp_path, TENSORBOARD)):
2532
os.makedirs(Path(self.exp_path, TENSORBOARD))
2633
self.writer = SummaryWriter(log_dir=Path(self.exp_path, "tensorboard"))
2734

35+
self.stored_metrics = {"losses": {}, "scores": {}}
36+
self.stored_metrics_path = Path(self.exp_path, "metrics_data.torch")
37+
if os.path.exists(self.stored_metrics_path):
38+
self.stored_metrics = torch.load(self.stored_metrics_path)
39+
40+
2841
def on_epoch_end(self, state: State):
2942
"""
3043
Writes Training, Validation and (if any) Test metrics to Tensorboard
@@ -46,6 +59,13 @@ def on_epoch_end(self, state: State):
4659

4760
self.writer.add_scalars(loss_name, loss_scalars, state.epoch)
4861

62+
if self.store_on_disk:
63+
t = "losses"
64+
if not k in self.stored_metrics[t]:
65+
self.stored_metrics[t][k] = [v.item()]
66+
else:
67+
self.stored_metrics[t][k].append(v.item())
68+
4969
for k, v in state.epoch_results[SCORES].items():
5070
score_scalars = {}
5171
# Remove training/validation/test prefix (coupling with Engine)
@@ -59,6 +79,20 @@ def on_epoch_end(self, state: State):
5979

6080
self.writer.add_scalars(score_name, score_scalars, state.epoch)
6181

82+
if self.store_on_disk:
83+
t = "scores"
84+
if not k in self.stored_metrics[t]:
85+
self.stored_metrics[t][k] = [v.item()]
86+
else:
87+
self.stored_metrics[t][k].append(v.item())
88+
89+
if self.store_on_disk:
90+
try:
91+
torch.save(self.stored_metrics, self.stored_metrics_path)
92+
except RuntimeError as e:
93+
print(e)
94+
95+
6296
def on_fit_end(self, state: State):
6397
"""
6498
Frees resources by closing the Tensorboard writer

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "pydgn"
7-
version = "1.5.5"
7+
version = "1.5.6"
88
description = "A Python Package for Deep Graph Networks"
99
authors = [ { name="Federico Errica", email="[email protected]" } ]
1010
readme = "README.md"

0 commit comments

Comments
 (0)