11"""
22A base class for global optimization including:
33
4- - WFS
5- - DFS
6- - QRS
4+ - WFS: Width First Sampling
5+ - DFS: Depth First Sampling
6+ - QRS: Quasi Random Sampling
77"""
88from __future__ import annotations
99from multiprocessing import Pool
10-
1110from concurrent .futures import TimeoutError
1211import signal
1312
@@ -37,7 +36,6 @@ def setup_worker_logger(log_file):
3736 filename = log_file ,
3837 level = logging .INFO )
3938
40-
4139# Update run_optimizer_with_timeout to accept a logger
4240def run_optimizer_with_timeout (args , logger ):
4341 """
@@ -481,7 +479,7 @@ def early_termination(self, success_rate):
481479
482480 def export_references (self , xtals , engs , N_min = 50 , dE = 2.5 , FMSE = 2.5 ):
483481 """
484- Add trainning data
482+ Add trainning data for FF optimization
485483
486484 Args:
487485 xtals: a list of pyxtals
@@ -498,19 +496,16 @@ def export_references(self, xtals, engs, N_min=50, dE=2.5, FMSE=2.5):
498496 _xtals = self .select_xtals (xtals , ids , N_max )
499497 print ("Select structures for FF optimization" , len (_xtals ))
500498
501- numMols = [xtal .numMols for xtal in _xtals ]
502- _xtals = [xtal .to_ase (resort = False ) for xtal in _xtals ]
503-
504499 # Initialize references
505500 if os .path .exists (self .reference_file ):
506501 ref_dics = self .parameters .load_references (self .reference_file )
502+ ref_ground_states = self .parameters .get_gs_from_ref_dics (ref_dics )
507503 else :
508504 ref_dics = []
509-
505+ ref_ground_states = []
510506
511507 # Add references
512508 os .chdir (self .workdir )
513-
514509 if len (ref_dics ) > 0 and self .check :
515510 ref_dics = self .parameters .cut_references_by_error (ref_dics ,
516511 params ,
@@ -522,20 +517,14 @@ def export_references(self, xtals, engs, N_min=50, dE=2.5, FMSE=2.5):
522517
523518 t0 = time ()
524519 N_selected = min ([N_min , self .ncpu ])
525- print ("Current number of reference structures" , len (ref_dics ))
526- print ("Create the reference data by augmentation" , N_selected )
527- if len (_xtals ) >= N_selected :
528- ids = self .random_state .choice (list (range (len (_xtals ))), N_selected )
529- _xtals = [_xtals [id ] for id in ids ]
530- numMols = [numMols [id ] for id in ids ]
531-
532- _ref_dics = self .parameters .add_multi_references (_xtals ,
533- numMols ,
534- augment = True ,
535- steps = 20 , #50,
536- N_vibs = 1 ,
537- logfile = "ase.log" )
520+ _ref_dics = self .parameters .add_references (_xtals , ref_ground_states , N_selected )
521+ # print(f"Current number of reference structures: {len(ref_dics)}")
522+ # print(f"Pick {len(_ref_dics)} reference data for agumentation")
523+ #print(_ref_dics); import sys; sys.exit()
524+
538525 ref_dics .extend (_ref_dics )
526+ aug_dics = self .parameters .augment_references (_ref_dics )
527+ ref_dics .extend (aug_dics )
539528 t1 = (time () - t0 ) / 60
540529 print (f"Add { len (_ref_dics )} references in { t1 :.2f} min" )
541530
@@ -584,7 +573,7 @@ def _prepare_chm_info(self, params0, params1=None, folder="calc", suffix="pyxtal
584573 ase_with_ff .write_charmmfiles (base = suffix )
585574 os .chdir (pwd )
586575
587- # Info
576+ # Return the atom_info
588577 return ase_with_ff .get_atom_info ()
589578
590579 def get_label (self , i , label = 'cpu' ):
@@ -724,8 +713,7 @@ def check_ref(self, reps=None, reference=None, filename="pyxtal.cif"):
724713 refernce: [pmg, eng]
725714 filename: filename
726715 """
727- if os .path .exists (filename ):
728- os .remove (filename )
716+ if os .path .exists (filename ): os .remove (filename )
729717
730718 if reference is not None :
731719 [pmg0 , eng ] = reference
@@ -772,6 +760,9 @@ def check_ref(self, reps=None, reference=None, filename="pyxtal.cif"):
772760 return False
773761
774762 def _get_local_optimization_args (self ):
763+ """
764+ Get the arguments for the local optimization
765+ """
775766 args = [
776767 randomizer ,
777768 optimizer ,
@@ -891,7 +882,6 @@ def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None):
891882 qrs (bool): Force mutation or not (related to QRS)
892883 """
893884 gen = self .generation
894- t0 = time ()
895885 args = self ._get_local_optimization_args ()
896886 if ids is None :
897887 ids = range (len (xtals ))
@@ -1126,8 +1116,6 @@ def save(self, filename):
11261116 #ET.SubElement(basic, "sites").text = str(self.sites)
11271117 #ET.SubElement(basic, "torsions").text = self.torsions
11281118 #ET.SubElement(basic, "ref_criteria").text = str(None) #self.ref_criteria
1129-
1130-
11311119 # Use prettify to get a pretty-printed XML string
11321120 pretty_xml = prettify (root )
11331121 with open (filename , "w" ) as f :
0 commit comments