Skip to content

Commit bc1c2ab

Browse files
authored
Merge pull request #118 from OpenBioSim/fix_trajectory_issue
Fix trajectory indexing issue on restart
2 parents 0e605cf + 66354b8 commit bc1c2ab

File tree

3 files changed

+38
-16
lines changed

3 files changed

+38
-16
lines changed

src/somd2/runner/_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,9 @@ def _checkpoint(
17621762
)
17631763
system.set_property("lambda", lam)
17641764

1765+
# Delete all frames from the system.
1766+
system.delete_all_frames()
1767+
17651768
# Stream the final system to file.
17661769
_sr.stream.save(system, self._filenames[index]["checkpoint"])
17671770

@@ -1796,6 +1799,9 @@ def _checkpoint(
17961799
)
17971800
system.set_property("lambda", lam)
17981801

1802+
# Delete all frames from the system.
1803+
system.delete_all_frames()
1804+
17991805
# Stream the checkpoint to file.
18001806
_sr.stream.save(system, self._filenames[index]["checkpoint"])
18011807

src/somd2/runner/_repex.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ def _create_dynamics(
247247
else:
248248
mols = system
249249

250+
# Delete an existing trajectory frames.
251+
mols.delete_all_frames()
252+
250253
# Overload the device and lambda value.
251254
dynamics_kwargs["device"] = device
252255
dynamics_kwargs["lambda_value"] = lam
@@ -721,11 +724,17 @@ def __init__(self, system, config):
721724
output_directory=self._config.output_directory,
722725
)
723726
else:
727+
_logger.debug("Restarting from file")
728+
724729
# Check to see if the simulation is already complete.
725730
time = self._system[0].time()
726731
if time > self._config.runtime - self._config.timestep:
727-
_logger.success(f"Simulation already complete. Exiting.")
732+
_logger.success("Simulation already complete. Exiting.")
728733
_sys.exit(0)
734+
else:
735+
_logger.info(
736+
f"Restarting at time {time}, time remaining = {self._config.runtime - time}"
737+
)
729738

730739
try:
731740
with open(self._repex_state, "rb") as f:
@@ -827,28 +836,28 @@ def run(self):
827836
else:
828837
cycles = int(ceil(cycles))
829838

830-
if self._config.checkpoint_frequency.value() > 0.0:
839+
# Store the current checkpoint frequency.
840+
checkpoint_frequency = self._config.checkpoint_frequency
841+
842+
if checkpoint_frequency.value() > 0.0:
831843
# Calculate the number of blocks and the remainder time.
832-
frac = (self._config.runtime / self._config.checkpoint_frequency).value()
844+
frac = (self._config.runtime / checkpoint_frequency).value()
833845

834846
# Handle the case where the runtime is less than the checkpoint frequency.
835847
if frac < 1.0:
836848
frac = 1.0
837-
self._config.checkpoint_frequency = str(self._config.runtime)
849+
checkpoint_frequency = self._config.runtime
838850

839851
num_blocks = int(frac)
840852
rem = round(frac - num_blocks, 12)
841853

842854
# Work out the number of repex cycles per block.
843-
frac = (
844-
self._config.checkpoint_frequency.value()
845-
/ self._config.energy_frequency.value()
846-
)
855+
frac = (checkpoint_frequency / self._config.energy_frequency).value()
847856

848857
# Handle the case where the checkpoint frequency is less than the energy frequency.
849858
if frac < 1.0:
850859
frac = 1.0
851-
self._config.checkpoint_frequency = str(self._config.energy_frequency)
860+
checkpoint_frequency = self._config.energy_frequency
852861

853862
# Store the number of repex cycles per block.
854863
cycles_per_checkpoint = int(frac)
@@ -1035,7 +1044,7 @@ def run(self):
10351044
repeat(num_blocks + int(rem > 0)),
10361045
repeat(i == cycles - 1),
10371046
):
1038-
if not result:
1047+
if error:
10391048
_logger.error(
10401049
f"Checkpoint failed for {_lam_sym} = "
10411050
f"{self._lambda_values[index]:.5f}: {error}"

src/somd2/runner/_runner.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ def run_window(self, index):
292292
else:
293293
system = self._system.clone()
294294

295+
# Delete an existing trajectory frames.
296+
system.delete_all_frames()
297+
295298
# GPU platform.
296299
if self._is_gpu:
297300
# Get a GPU from the pool.
@@ -644,22 +647,26 @@ def generate_lam_vals(lambda_base, increment=0.001):
644647
else:
645648
num_energy_neighbours = None
646649

650+
# Store the current checkpoint frequency.
651+
checkpoint_frequency = self._config.checkpoint_frequency
652+
647653
# Store the checkpoint time in nanoseconds.
648-
checkpoint_interval = self._config.checkpoint_frequency.to("ns")
654+
checkpoint_interval = checkpoint_frequency.to("ns")
649655

650656
# Store the start time.
651657
start = _timer()
652658

653659
# Run the simulation, checkpointing in blocks.
654-
if self._config.checkpoint_frequency.value() > 0.0:
660+
if checkpoint_frequency.value() > 0.0:
655661

656662
# Calculate the number of blocks and the remainder time.
657-
frac = (time / self._config.checkpoint_frequency).value()
663+
frac = (time / checkpoint_frequency).value()
658664

659665
# Handle the case where the runtime is less than the checkpoint frequency.
660666
if frac < 1.0:
661667
frac = 1.0
662-
self._config.checkpoint_frequency = f"{time} ps"
668+
checkpoint_frequency = _sr.u(f"{time} ps")
669+
checkpoint_interval = checkpoint_frequency.to("ns")
663670

664671
num_blocks = int(frac)
665672
rem = round(frac - num_blocks, 12)
@@ -684,7 +691,7 @@ def generate_lam_vals(lambda_base, increment=0.001):
684691
next_frame = self._config.frame_frequency
685692

686693
# Loop until we reach the runtime.
687-
while runtime <= self._config.checkpoint_frequency:
694+
while runtime <= checkpoint_frequency:
688695
# Run the dynamics in blocks of the GCMC frequency.
689696
dynamics.run(
690697
self._config.gcmc_frequency,
@@ -725,7 +732,7 @@ def generate_lam_vals(lambda_base, increment=0.001):
725732

726733
else:
727734
dynamics.run(
728-
self._config.checkpoint_frequency,
735+
checkpoint_frequency,
729736
energy_frequency=self._config.energy_frequency,
730737
frame_frequency=self._config.frame_frequency,
731738
lambda_windows=lambda_array,

0 commit comments

Comments
 (0)