3131import torch as _torch
3232import numpy as _np
3333
34- from typing import List
34+ from typing import List , Optional
3535
3636from ._emle import EMLE as _EMLE
3737from ._emle import _has_nnpops
@@ -198,12 +198,21 @@ def __init__(
198198 )
199199
200200 if not isinstance (mace_model , (list , tuple )):
201- mace_model = [mace_model ] if mace_model is None or isinstance (mace_model , str ) else None
201+ mace_model = (
202+ [mace_model ]
203+ if mace_model is None or isinstance (mace_model , str )
204+ else None
205+ )
202206
203- if mace_model is None or any (not isinstance (i , (str , type (None ))) for i in mace_model ):
204- raise TypeError ("'mace_model' must be a list, tuple, or str, with elements of type str or None" )
207+ if mace_model is None or any (
208+ not isinstance (i , (str , type (None ))) for i in mace_model
209+ ):
210+ raise TypeError (
211+ "'mace_model' must be a list, tuple, or str, with elements of type str or None"
212+ )
205213
206214 from mace .tools .scripts_utils import extract_config_mace_model
215+
207216 self ._mace_models = _torch .nn .ModuleList ()
208217 for model in mace_model :
209218 source_model = self ._load_mace_model (model , device )
@@ -392,9 +401,9 @@ def to(self, *args, **kwargs):
392401 """
393402 self ._emle = self ._emle .to (* args , ** kwargs )
394403 self ._mace = self ._mace .to (* args , ** kwargs )
395- self ._mace_models = _torch .nn .ModuleList ([
396- model .to (* args , ** kwargs ) for model in self ._mace_models
397- ] )
404+ self ._mace_models = _torch .nn .ModuleList (
405+ [ model .to (* args , ** kwargs ) for model in self ._mace_models ]
406+ )
398407 return self
399408
400409 def cpu (self , ** kwargs ):
@@ -405,9 +414,9 @@ def cpu(self, **kwargs):
405414 self ._mace = self ._mace .cpu (** kwargs )
406415 if self ._atomic_numbers is not None :
407416 self ._atomic_numbers = self ._atomic_numbers .cpu (** kwargs )
408- self ._mace_models = _torch .nn .ModuleList ([
409- model .cpu (** kwargs ) for model in self ._mace_models
410- ] )
417+ self ._mace_models = _torch .nn .ModuleList (
418+ [ model .cpu (** kwargs ) for model in self ._mace_models ]
419+ )
411420 return self
412421
413422 def cuda (self , ** kwargs ):
@@ -418,9 +427,9 @@ def cuda(self, **kwargs):
418427 self ._mace = self ._mace .cuda (** kwargs )
419428 if self ._atomic_numbers is not None :
420429 self ._atomic_numbers = self ._atomic_numbers .cuda (** kwargs )
421- self ._mace_models = _torch .nn .ModuleList ([
422- model .cuda (** kwargs ) for model in self ._mace_models
423- ] )
430+ self ._mace_models = _torch .nn .ModuleList (
431+ [ model .cuda (** kwargs ) for model in self ._mace_models ]
432+ )
424433 return self
425434
426435 def double (self ):
@@ -429,9 +438,9 @@ def double(self):
429438 """
430439 self ._emle = self ._emle .double ()
431440 self ._mace = self ._mace .double ()
432- self ._mace_models = _torch .nn .ModuleList ([
433- model .double () for model in self ._mace_models
434- ] )
441+ self ._mace_models = _torch .nn .ModuleList (
442+ [ model .double () for model in self ._mace_models ]
443+ )
435444 return self
436445
437446 def float (self ):
@@ -440,9 +449,9 @@ def float(self):
440449 """
441450 self ._emle = self ._emle .float ()
442451 self ._mace = self ._mace .float ()
443- self ._mace_models = _torch .nn .ModuleList ([
444- model .float () for model in self ._mace_models
445- ] )
452+ self ._mace_models = _torch .nn .ModuleList (
453+ [ model .float () for model in self ._mace_models ]
454+ )
446455 return self
447456
448457 def forward (
@@ -451,6 +460,7 @@ def forward(
451460 charges_mm : Tensor ,
452461 xyz_qm : Tensor ,
453462 xyz_mm : Tensor ,
463+ cell : Optional [Tensor ] = None ,
454464 qm_charge : int = 0 ,
455465 ) -> Tensor :
456466 """
@@ -471,6 +481,9 @@ def forward(
471481 xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
472482 Positions of MM atoms in Angstrom.
473483
484+ cell: torch.Tensor (3, 3) or (BATCH, 3, 3), optional
485+ The simulation cell vectors in Angstrom.
486+
474487 qm_charge: int
475488 The charge on the QM region.
476489
@@ -489,6 +502,8 @@ def forward(
489502 xyz_qm = xyz_qm .unsqueeze (0 )
490503 xyz_mm = xyz_mm .unsqueeze (0 )
491504 charges_mm = charges_mm .unsqueeze (0 )
505+ if cell is not None and cell .ndim == 2 :
506+ cell = cell .unsqueeze (0 )
492507
493508 # Store the number of batches.
494509 num_batches = atomic_numbers .shape [0 ]
@@ -497,8 +512,17 @@ def forward(
497512 num_models = len (self ._mace_models )
498513
499514 # Create tensors to store the data for QbC.
500- self ._E_vac_qbc = _torch .empty (num_models , num_batches , dtype = self ._dtype , device = device )
501- self ._grads_qbc = _torch .empty (num_models , num_batches , xyz_qm .shape [1 ], 3 , dtype = self ._dtype , device = device )
515+ self ._E_vac_qbc = _torch .empty (
516+ num_models , num_batches , dtype = self ._dtype , device = device
517+ )
518+ self ._grads_qbc = _torch .empty (
519+ num_models ,
520+ num_batches ,
521+ xyz_qm .shape [1 ],
522+ 3 ,
523+ dtype = self ._dtype ,
524+ device = device ,
525+ )
502526
503527 # Create tensors to store the results.
504528 results_E_vac = _torch .empty (num_batches , dtype = self ._dtype , device = device )
@@ -516,6 +540,10 @@ def forward(
516540 xyz_qm [i ], None , self ._r_max , self ._dtype , device
517541 )
518542
543+ # Get the cell for this configuration.
544+ if cell is not None :
545+ cell = cell .to (self ._dtype ).to (device )
546+
519547 if not _torch .equal (atomic_numbers [i ], self ._atomic_numbers ):
520548 # Update the node attributes if the atomic numbers have changed.
521549 self ._node_attrs = (
@@ -557,19 +585,25 @@ def forward(
557585 results_E_vac [i ] = E_vac [0 ] * EV_TO_HARTREE
558586
559587 # Decouple the positions from the computation graph for the next models.
560- input_dict ["positions" ] = input_dict ["positions" ].clone ().detach ().requires_grad_ (True )
588+ input_dict ["positions" ] = (
589+ input_dict ["positions" ].clone ().detach ().requires_grad_ (True )
590+ )
561591
562592 # Do inference for the other models.
563593 if len (self ._mace_models ) > 1 :
564594 for j , mace in enumerate (self ._mace_models ):
565- E_vac_qbc = mace (input_dict , compute_force = False )["interaction_energy" ]
595+ E_vac_qbc = mace (input_dict , compute_force = False )[
596+ "interaction_energy"
597+ ]
566598
567599 assert (
568600 E_vac_qbc is not None
569601 ), "The model did not return any energy. Please check the input."
570602
571603 # Calculate the gradients
572- grads_qbc = _torch .autograd .grad ([E_vac_qbc ], [input_dict ["positions" ]])[0 ]
604+ grads_qbc = _torch .autograd .grad (
605+ [E_vac_qbc ], [input_dict ["positions" ]]
606+ )[0 ]
573607 assert grads_qbc is not None , "Gradient computation failed"
574608
575609 # Store the results.
@@ -585,7 +619,7 @@ def forward(
585619 else :
586620 # Get the EMLE energy components.
587621 E_emle = self ._emle (
588- atomic_numbers , charges_mm , xyz_qm , xyz_mm , qm_charge
622+ atomic_numbers , charges_mm , xyz_qm , xyz_mm , cell , qm_charge
589623 )
590624 results_E_emle_static [i ] = E_emle [0 ][0 ]
591625 results_E_emle_induced [i ] = E_emle [1 ][0 ]
0 commit comments