Skip to content

Commit 0adbcfd

Browse files
committed
update optimize code for ff training
1 parent c3c5449 commit 0adbcfd

File tree

4 files changed

+24
-45
lines changed

4 files changed

+24
-45
lines changed

pyxtal/optimize/DFS.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
"""
2-
DFS sampler
2+
Global optimization using Depth First Sampling
33
"""
4-
54
from __future__ import annotations
6-
75
from time import time
86
from typing import TYPE_CHECKING
97

108
import numpy as np
119
from numpy.random import Generator
1210
from pymatgen.analysis.structure_matcher import StructureMatcher
13-
1411
from pyxtal.optimize.base import GlobalOptimize
1512

1613
if TYPE_CHECKING:
@@ -382,4 +379,4 @@ def load(cls, filename):
382379
strs = "Final {:8s} [{:2d}]{:10s} ".format(name, sum(xtal.numMols), spg)
383380
strs += "{:3d}m {:2d} {:6.1f}".format(t1, N_torsion, wt)
384381
strs += "{:12.3f} {:20s} {:s}".format(eng, mytag, smile)
385-
print(strs)
382+
print(strs)

pyxtal/optimize/QRS.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
"""
2-
Global Optimizer
2+
Global Optimizer base on Quasi-Random Sampling
33
"""
4-
54
from __future__ import annotations
6-
75
from time import time
86
from typing import TYPE_CHECKING
97

108
import numpy as np
119
from scipy.stats import qmc
12-
13-
from numpy.random import Generator
1410
from pymatgen.analysis.structure_matcher import StructureMatcher
1511

1612
from pyxtal.optimize.base import GlobalOptimize
@@ -318,4 +314,4 @@ def _run(self, pool=None):
318314
return success_rate
319315

320316
if __name__ == "__main__":
321-
print("test")
317+
print("test")

pyxtal/optimize/WFS.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
"""
2-
WFS sampler
2+
Global optimization using the Stochastic Width First Sampling (WFS) algorithm
33
"""
44

55
from __future__ import annotations
6-
76
from time import time
87
from typing import TYPE_CHECKING
98

109
import numpy as np
1110
from numpy.random import Generator
1211
from pymatgen.analysis.structure_matcher import StructureMatcher
13-
1412
from pyxtal.optimize.base import GlobalOptimize
1513

1614
if TYPE_CHECKING:
@@ -369,4 +367,4 @@ def load(cls, filename):
369367
strs = "Final {:8s} [{:2d}]{:10s} ".format(name, sum(xtal.numMols), spg)
370368
strs += "{:3d}m {:2d} {:6.1f}".format(t1, N_torsion, wt)
371369
strs += "{:12.3f} {:20s} {:s}".format(eng, mytag, smile)
372-
print(strs)
370+
print(strs)

pyxtal/optimize/base.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
"""
22
A base class for global optimization including:
33
4-
- WFS
5-
- DFS
6-
- QRS
4+
- WFS: Width First Sampling
5+
- DFS: Depth First Sampling
6+
- QRS: Quasi Random Sampling
77
"""
88
from __future__ import annotations
99
from multiprocessing import Pool
10-
1110
from concurrent.futures import TimeoutError
1211
import signal
1312

@@ -37,7 +36,6 @@ def setup_worker_logger(log_file):
3736
filename=log_file,
3837
level=logging.INFO)
3938

40-
4139
# Update run_optimizer_with_timeout to accept a logger
4240
def run_optimizer_with_timeout(args, logger):
4341
"""
@@ -481,7 +479,7 @@ def early_termination(self, success_rate):
481479

482480
def export_references(self, xtals, engs, N_min=50, dE=2.5, FMSE=2.5):
483481
"""
484-
Add trainning data
482+
Add trainning data for FF optimization
485483
486484
Args:
487485
xtals: a list of pyxtals
@@ -498,19 +496,16 @@ def export_references(self, xtals, engs, N_min=50, dE=2.5, FMSE=2.5):
498496
_xtals = self.select_xtals(xtals, ids, N_max)
499497
print("Select structures for FF optimization", len(_xtals))
500498

501-
numMols = [xtal.numMols for xtal in _xtals]
502-
_xtals = [xtal.to_ase(resort=False) for xtal in _xtals]
503-
504499
# Initialize references
505500
if os.path.exists(self.reference_file):
506501
ref_dics = self.parameters.load_references(self.reference_file)
502+
ref_ground_states = self.parameters.get_gs_from_ref_dics(ref_dics)
507503
else:
508504
ref_dics = []
509-
505+
ref_ground_states = []
510506

511507
# Add references
512508
os.chdir(self.workdir)
513-
514509
if len(ref_dics) > 0 and self.check:
515510
ref_dics = self.parameters.cut_references_by_error(ref_dics,
516511
params,
@@ -522,20 +517,14 @@ def export_references(self, xtals, engs, N_min=50, dE=2.5, FMSE=2.5):
522517

523518
t0 = time()
524519
N_selected = min([N_min, self.ncpu])
525-
print("Current number of reference structures", len(ref_dics))
526-
print("Create the reference data by augmentation", N_selected)
527-
if len(_xtals) >= N_selected:
528-
ids = self.random_state.choice(list(range(len(_xtals))), N_selected)
529-
_xtals = [_xtals[id] for id in ids]
530-
numMols = [numMols[id] for id in ids]
531-
532-
_ref_dics = self.parameters.add_multi_references(_xtals,
533-
numMols,
534-
augment=True,
535-
steps=20, #50,
536-
N_vibs=1,
537-
logfile="ase.log")
520+
_ref_dics = self.parameters.add_references(_xtals, ref_ground_states, N_selected)
521+
# print(f"Current number of reference structures: {len(ref_dics)}")
522+
# print(f"Pick {len(_ref_dics)} reference data for agumentation")
523+
#print(_ref_dics); import sys; sys.exit()
524+
538525
ref_dics.extend(_ref_dics)
526+
aug_dics = self.parameters.augment_references(_ref_dics)
527+
ref_dics.extend(aug_dics)
539528
t1 = (time() - t0) / 60
540529
print(f"Add {len(_ref_dics)} references in {t1:.2f} min")
541530

@@ -584,7 +573,7 @@ def _prepare_chm_info(self, params0, params1=None, folder="calc", suffix="pyxtal
584573
ase_with_ff.write_charmmfiles(base=suffix)
585574
os.chdir(pwd)
586575

587-
# Info
576+
# Return the atom_info
588577
return ase_with_ff.get_atom_info()
589578

590579
def get_label(self, i, label='cpu'):
@@ -724,8 +713,7 @@ def check_ref(self, reps=None, reference=None, filename="pyxtal.cif"):
724713
refernce: [pmg, eng]
725714
filename: filename
726715
"""
727-
if os.path.exists(filename):
728-
os.remove(filename)
716+
if os.path.exists(filename): os.remove(filename)
729717

730718
if reference is not None:
731719
[pmg0, eng] = reference
@@ -772,6 +760,9 @@ def check_ref(self, reps=None, reference=None, filename="pyxtal.cif"):
772760
return False
773761

774762
def _get_local_optimization_args(self):
763+
"""
764+
Get the arguments for the local optimization
765+
"""
775766
args = [
776767
randomizer,
777768
optimizer,
@@ -891,7 +882,6 @@ def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None):
891882
qrs (bool): Force mutation or not (related to QRS)
892883
"""
893884
gen = self.generation
894-
t0 = time()
895885
args = self._get_local_optimization_args()
896886
if ids is None:
897887
ids = range(len(xtals))
@@ -1126,8 +1116,6 @@ def save(self, filename):
11261116
#ET.SubElement(basic, "sites").text = str(self.sites)
11271117
#ET.SubElement(basic, "torsions").text = self.torsions
11281118
#ET.SubElement(basic, "ref_criteria").text = str(None) #self.ref_criteria
1129-
1130-
11311119
# Use prettify to get a pretty-printed XML string
11321120
pretty_xml = prettify(root)
11331121
with open(filename, "w") as f:

0 commit comments

Comments
 (0)