Skip to content

Commit c7f4e3c

Browse files
committed
Add support for passing cell kwarg to EMLE forward method.
1 parent 891a4e9 commit c7f4e3c

4 files changed

Lines changed: 100 additions & 32 deletions

File tree

emle/calculator.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,7 @@ def _calculate_energy_and_gradients(
11311131
xyz_qm,
11321132
xyz_mm,
11331133
atoms=None,
1134+
cell=None,
11341135
charge=0,
11351136
):
11361137
"""
@@ -1154,6 +1155,9 @@ def _calculate_energy_and_gradients(
11541155
atoms: ase.Atoms
11551156
The atoms object for the QM region.
11561157
1158+
cell: numpy.ndarray, (3, 3)
1159+
The simulation cell vectors in Angstrom.
1160+
11571161
charge: int
11581162
The total charge of the QM region.
11591163
@@ -1249,6 +1253,8 @@ def _calculate_energy_and_gradients(
12491253
xyz_mm = _torch.tensor(
12501254
xyz_mm, dtype=_torch.float32, device=self._device, requires_grad=True
12511255
)
1256+
if cell is not None:
1257+
cell = _torch.tensor(cell, dtype=_torch.float32, device=self._device)
12521258

12531259
# Are there any MM atoms?
12541260
allow_unused = len(charges_mm) == 0
@@ -1263,7 +1269,7 @@ def _calculate_energy_and_gradients(
12631269

12641270
# Compute the energy.
12651271
E = delta_model(
1266-
atomic_numbers, null_charges_mm, xyz_qm, null_xyz_mm, charge
1272+
atomic_numbers, null_charges_mm, xyz_qm, null_xyz_mm, cell, charge
12671273
)
12681274

12691275
# Compute the gradients.
@@ -1286,7 +1292,9 @@ def _calculate_energy_and_gradients(
12861292
if base_model is None:
12871293
try:
12881294
if len(xyz_mm) > 0:
1289-
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge)
1295+
E = self._emle(
1296+
atomic_numbers, charges_mm, xyz_qm, xyz_mm, cell, charge
1297+
)
12901298
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(
12911299
E.sum(), (xyz_qm, xyz_mm), allow_unused=allow_unused
12921300
)
@@ -1313,7 +1321,9 @@ def _calculate_energy_and_gradients(
13131321
model = base_model.original_name
13141322
try:
13151323
with _torch.jit.optimized_execution(False):
1316-
E = base_model(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge)
1324+
E = base_model(
1325+
atomic_numbers, charges_mm, xyz_qm, xyz_mm, cell, charge
1326+
)
13171327
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(
13181328
E.sum(), (xyz_qm, xyz_mm), allow_unused=allow_unused
13191329
)
@@ -1357,7 +1367,7 @@ def _calculate_energy_and_gradients(
13571367
E_mm_qm_vac, grad_mm_qm_vac = 0.0, _np.zeros_like(xyz_qm_np)
13581368

13591369
# Compute the embedding contributions.
1360-
E = self._emle_mm(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge)
1370+
E = self._emle_mm(atomic_numbers, charges_mm, xyz_qm, xyz_mm, cell, charge)
13611371
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(
13621372
E.sum(), (xyz_qm, xyz_mm), allow_unused=allow_unused
13631373
)
@@ -1448,7 +1458,9 @@ def set_lambda_interpolate(self, lambda_interpolate):
14481458
# Reset the first step flag.
14491459
self._is_first_step = not self._restart
14501460

1451-
def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None):
1461+
def _sire_callback(
1462+
self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, cell=None, idx_mm=None
1463+
):
14521464
"""
14531465
A callback function to be used with Sire.
14541466
@@ -1467,6 +1479,9 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
14671479
xyz_mm: [[float, float, float]]
14681480
The coordinates of the MM atoms in Angstrom.
14691481
1482+
cell: [[float, float, float], [float, float, float], [float, float, float]]
1483+
The simulation cell vectors in Angstrom.
1484+
14701485
idx_mm: [int]
14711486
A list of indices of the MM atoms in the QM/MM region.
14721487
Note that len(idx_mm) <= len(charges_mm) since it only
@@ -1493,6 +1508,8 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
14931508
charges_mm = _np.array(charges_mm)
14941509
xyz_qm = _np.array(xyz_qm)
14951510
xyz_mm = _np.array(xyz_mm)
1511+
if cell is not None:
1512+
cell = _np.array(cell)
14961513

14971514
# Make sure that the number of QM atoms matches the number of MM charges
14981515
# when using mm embedding.
@@ -1512,6 +1529,7 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
15121529
charges_mm,
15131530
xyz_qm,
15141531
xyz_mm,
1532+
cell=cell,
15151533
)
15161534

15171535
# Store the number of MM atoms.

emle/models/_ani.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ def forward(
348348
charges_mm: Tensor,
349349
xyz_qm: Tensor,
350350
xyz_mm: Tensor,
351+
cell: Optional[Tensor] = None,
351352
qm_charge: int = 0,
352353
) -> Tensor:
353354
"""
@@ -368,6 +369,9 @@ def forward(
368369
xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
369370
Positions of MM atoms in Angstrom.
370371
372+
cell: torch.Tensor (3, 3) or (BATCH, 3, 3), optional
373+
The simulation cell vectors in Angstrom.
374+
371375
qm_charge: int or torch.Tensor (BATCH,)
372376
The charge on the QM region.
373377
@@ -404,7 +408,7 @@ def forward(
404408
self._emle._emle_base._emle_aev_computer._aev = self._ani2x.aev_computer._aev
405409

406410
# Get the EMLE energy components.
407-
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)
411+
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, cell, qm_charge)
408412

409413
# Return the ANI2x and EMLE energy components.
410414
return _torch.stack((E_vac, E_emle[0], E_emle[1]))

emle/models/_emle.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import torchani as _torchani
3535

3636
from torch import Tensor
37-
from typing import Union
37+
from typing import Optional, Union
3838

3939
from . import _patches
4040
from . import EMLEBase as _EMLEBase
@@ -417,6 +417,7 @@ def forward(
417417
charges_mm: Tensor,
418418
xyz_qm: Tensor,
419419
xyz_mm: Tensor,
420+
cell: Optional[Tensor] = None,
420421
qm_charge: Union[int, Tensor] = 0,
421422
) -> Tensor:
422423
"""
@@ -437,6 +438,9 @@ def forward(
437438
xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
438439
Positions of MM atoms in Angstrom.
439440
441+
cell: torch.Tensor (3, 3) or (BATCH, 3, 3), optional
442+
The simulation cell vectors in Angstrom.
443+
440444
qm_charge: int or torch.Tensor (BATCH,)
441445
The charge on the QM region.
442446
@@ -461,6 +465,14 @@ def forward(
461465

462466
batch_size = self._atomic_numbers.shape[0]
463467

468+
# Ensure cell is a tensor and repeat for batch size if necessary.
469+
if cell is not None:
470+
if isinstance(cell, _torch.Tensor):
471+
if cell.ndim == 2:
472+
cell = cell.repeat(batch_size, 1, 1).to(self._device)
473+
else:
474+
raise TypeError("'cell' must be of type 'torch.Tensor'")
475+
464476
# Ensure qm_charge is a tensor and repeat for batch size if necessary
465477
if isinstance(qm_charge, int):
466478
qm_charge = _torch.full(

emle/models/_mace.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import torch as _torch
3232
import numpy as _np
3333

34-
from typing import List
34+
from typing import List, Optional
3535

3636
from ._emle import EMLE as _EMLE
3737
from ._emle import _has_nnpops
@@ -198,12 +198,21 @@ def __init__(
198198
)
199199

200200
if not isinstance(mace_model, (list, tuple)):
201-
mace_model = [mace_model] if mace_model is None or isinstance(mace_model, str) else None
201+
mace_model = (
202+
[mace_model]
203+
if mace_model is None or isinstance(mace_model, str)
204+
else None
205+
)
202206

203-
if mace_model is None or any(not isinstance(i, (str, type(None))) for i in mace_model):
204-
raise TypeError("'mace_model' must be a list, tuple, or str, with elements of type str or None")
207+
if mace_model is None or any(
208+
not isinstance(i, (str, type(None))) for i in mace_model
209+
):
210+
raise TypeError(
211+
"'mace_model' must be a list, tuple, or str, with elements of type str or None"
212+
)
205213

206214
from mace.tools.scripts_utils import extract_config_mace_model
215+
207216
self._mace_models = _torch.nn.ModuleList()
208217
for model in mace_model:
209218
source_model = self._load_mace_model(model, device)
@@ -392,9 +401,9 @@ def to(self, *args, **kwargs):
392401
"""
393402
self._emle = self._emle.to(*args, **kwargs)
394403
self._mace = self._mace.to(*args, **kwargs)
395-
self._mace_models = _torch.nn.ModuleList([
396-
model.to(*args, **kwargs) for model in self._mace_models
397-
])
404+
self._mace_models = _torch.nn.ModuleList(
405+
[model.to(*args, **kwargs) for model in self._mace_models]
406+
)
398407
return self
399408

400409
def cpu(self, **kwargs):
@@ -405,9 +414,9 @@ def cpu(self, **kwargs):
405414
self._mace = self._mace.cpu(**kwargs)
406415
if self._atomic_numbers is not None:
407416
self._atomic_numbers = self._atomic_numbers.cpu(**kwargs)
408-
self._mace_models = _torch.nn.ModuleList([
409-
model.cpu(**kwargs) for model in self._mace_models
410-
])
417+
self._mace_models = _torch.nn.ModuleList(
418+
[model.cpu(**kwargs) for model in self._mace_models]
419+
)
411420
return self
412421

413422
def cuda(self, **kwargs):
@@ -418,9 +427,9 @@ def cuda(self, **kwargs):
418427
self._mace = self._mace.cuda(**kwargs)
419428
if self._atomic_numbers is not None:
420429
self._atomic_numbers = self._atomic_numbers.cuda(**kwargs)
421-
self._mace_models = _torch.nn.ModuleList([
422-
model.cuda(**kwargs) for model in self._mace_models
423-
])
430+
self._mace_models = _torch.nn.ModuleList(
431+
[model.cuda(**kwargs) for model in self._mace_models]
432+
)
424433
return self
425434

426435
def double(self):
@@ -429,9 +438,9 @@ def double(self):
429438
"""
430439
self._emle = self._emle.double()
431440
self._mace = self._mace.double()
432-
self._mace_models = _torch.nn.ModuleList([
433-
model.double() for model in self._mace_models
434-
])
441+
self._mace_models = _torch.nn.ModuleList(
442+
[model.double() for model in self._mace_models]
443+
)
435444
return self
436445

437446
def float(self):
@@ -440,9 +449,9 @@ def float(self):
440449
"""
441450
self._emle = self._emle.float()
442451
self._mace = self._mace.float()
443-
self._mace_models = _torch.nn.ModuleList([
444-
model.float() for model in self._mace_models
445-
])
452+
self._mace_models = _torch.nn.ModuleList(
453+
[model.float() for model in self._mace_models]
454+
)
446455
return self
447456

448457
def forward(
@@ -451,6 +460,7 @@ def forward(
451460
charges_mm: Tensor,
452461
xyz_qm: Tensor,
453462
xyz_mm: Tensor,
463+
cell: Optional[Tensor] = None,
454464
qm_charge: int = 0,
455465
) -> Tensor:
456466
"""
@@ -471,6 +481,9 @@ def forward(
471481
xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
472482
Positions of MM atoms in Angstrom.
473483
484+
cell: torch.Tensor (3, 3) or (BATCH, 3, 3), optional
485+
The simulation cell vectors in Angstrom.
486+
474487
qm_charge: int
475488
The charge on the QM region.
476489
@@ -489,6 +502,8 @@ def forward(
489502
xyz_qm = xyz_qm.unsqueeze(0)
490503
xyz_mm = xyz_mm.unsqueeze(0)
491504
charges_mm = charges_mm.unsqueeze(0)
505+
if cell is not None and cell.ndim == 2:
506+
cell = cell.unsqueeze(0)
492507

493508
# Store the number of batches.
494509
num_batches = atomic_numbers.shape[0]
@@ -497,8 +512,17 @@ def forward(
497512
num_models = len(self._mace_models)
498513

499514
# Create tensors to store the data for QbC.
500-
self._E_vac_qbc = _torch.empty(num_models, num_batches, dtype=self._dtype, device=device)
501-
self._grads_qbc = _torch.empty(num_models, num_batches, xyz_qm.shape[1], 3, dtype=self._dtype, device=device)
515+
self._E_vac_qbc = _torch.empty(
516+
num_models, num_batches, dtype=self._dtype, device=device
517+
)
518+
self._grads_qbc = _torch.empty(
519+
num_models,
520+
num_batches,
521+
xyz_qm.shape[1],
522+
3,
523+
dtype=self._dtype,
524+
device=device,
525+
)
502526

503527
# Create tensors to store the results.
504528
results_E_vac = _torch.empty(num_batches, dtype=self._dtype, device=device)
@@ -516,6 +540,10 @@ def forward(
516540
xyz_qm[i], None, self._r_max, self._dtype, device
517541
)
518542

543+
# Cast the cell to the dtype and device used by the MACE model.
544+
if cell is not None:
545+
cell = cell.to(self._dtype).to(device)
546+
519547
if not _torch.equal(atomic_numbers[i], self._atomic_numbers):
520548
# Update the node attributes if the atomic numbers have changed.
521549
self._node_attrs = (
@@ -557,19 +585,25 @@ def forward(
557585
results_E_vac[i] = E_vac[0] * EV_TO_HARTREE
558586

559587
# Decouple the positions from the computation graph for the next models.
560-
input_dict["positions"] = input_dict["positions"].clone().detach().requires_grad_(True)
588+
input_dict["positions"] = (
589+
input_dict["positions"].clone().detach().requires_grad_(True)
590+
)
561591

562592
# Do inference for the other models.
563593
if len(self._mace_models) > 1:
564594
for j, mace in enumerate(self._mace_models):
565-
E_vac_qbc = mace(input_dict, compute_force=False)["interaction_energy"]
595+
E_vac_qbc = mace(input_dict, compute_force=False)[
596+
"interaction_energy"
597+
]
566598

567599
assert (
568600
E_vac_qbc is not None
569601
), "The model did not return any energy. Please check the input."
570602

571603
# Calculate the gradients
572-
grads_qbc = _torch.autograd.grad([E_vac_qbc], [input_dict["positions"]])[0]
604+
grads_qbc = _torch.autograd.grad(
605+
[E_vac_qbc], [input_dict["positions"]]
606+
)[0]
573607
assert grads_qbc is not None, "Gradient computation failed"
574608

575609
# Store the results.
@@ -585,7 +619,7 @@ def forward(
585619
else:
586620
# Get the EMLE energy components.
587621
E_emle = self._emle(
588-
atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge
622+
atomic_numbers, charges_mm, xyz_qm, xyz_mm, cell, qm_charge
589623
)
590624
results_E_emle_static[i] = E_emle[0][0]
591625
results_E_emle_induced[i] = E_emle[1][0]

0 commit comments

Comments
 (0)