Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
456daec
oeq as optional dependency
abhijeetgangan Jan 23, 2026
f398a28
oeq integration
abhijeetgangan Jan 23, 2026
c48b64c
run oeq test in CI
abhijeetgangan Jan 23, 2026
1f36a8c
update calculator and add tests
abhijeetgangan Jan 23, 2026
ec7fa1b
Update jax-md tests to let them use kernels and fix padding for the i…
abhijeetgangan Jan 23, 2026
3345f7e
add phonon calculation
abhijeetgangan Jan 23, 2026
38eff6a
Better import check for jax
abhijeetgangan Jan 26, 2026
ec50b0c
Merge branch 'main' into ag/oeq_integration
abhijeetgangan Feb 4, 2026
4d8df98
Update uv lock
abhijeetgangan Feb 4, 2026
dab0624
skip test in CI and add suggestions
abhijeetgangan Feb 4, 2026
feb946c
Move instructions
abhijeetgangan Feb 5, 2026
eb1ea39
remove nacl example
abhijeetgangan Feb 5, 2026
089b9cd
revert ag's original; check for 0e unnecessary
teddykoker Feb 5, 2026
545c71c
add kernel to pft configs
teddykoker Feb 5, 2026
065895c
add OEQ_NOTORCH=1 to jax model, make tp_conv a static field
teddykoker Feb 5, 2026
b26a87b
use kernel in jax training
teddykoker Feb 5, 2026
b6e50a0
adding pmap error reproducer (will remove)
teddykoker Feb 5, 2026
a3a85a0
add use_kernel in jax calculator and weight conversion
teddykoker Feb 6, 2026
d449442
fix test
teddykoker Feb 6, 2026
090bd43
merge main
teddykoker Feb 6, 2026
b8a1155
add comment
teddykoker Feb 6, 2026
f09cfa9
rm pmap example/data, update oeq version in uv lock
teddykoker Feb 9, 2026
5802425
move test/torch to tests/torch_tests. This avoids "import torch" succ…
teddykoker Feb 10, 2026
14a9d05
update readme
teddykoker Feb 10, 2026
aa9c2a3
only import oeq without torch if torch is not installed (this lets bo…
teddykoker Feb 10, 2026
8df8f66
update dependencies (new oeq release)
teddykoker Feb 10, 2026
ead1dba
readme
teddykoker Feb 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ Model | Dataset | Theory | Reference
pip install nequix
```

or for torch
to use [OpenEquivariance](https://github.com/PASSIONLab/OpenEquivariance) kernels,

```bash
pip install nequix[oeq]
# needs to be run after installation:
uv pip install openequivariance_extjax --no-build-isolation
```

or for torch (also with kernels):

```bash
pip install nequix[torch]
Expand All @@ -37,14 +45,16 @@ atoms = ...
atoms.calc = NequixCalculator("nequix-mp-1", backend="jax")
```

or if you want to use the faster PyTorch + kernels backend
or if you want to use the torch backend:

```python
...
atoms.calc = NequixCalculator("nequix-mp-1", backend="torch")
...
```

The two backends are typically comparable in speed when kernels are enabled.

#### NequixCalculator

Arguments
Expand All @@ -53,7 +63,7 @@ Arguments
- `backend` ({"jax", "torch"}, default "jax"): Compute backend.
- `capacity_multiplier` (float, default 1.1): JAX-only; padding factor to limit recompiles.
- `use_compile` (bool, default True): Torch-only; on GPU, uses `torch.compile()`.
- `use_kernel` (bool, default True): Torch-only; on GPU, use [OpenEquivariance](https://github.com/PASSIONLab/OpenEquivariance) kernels.
- `use_kernel` (bool, default True): on GPU, use [OpenEquivariance](https://github.com/PASSIONLab/OpenEquivariance) kernels.

### Training

Expand Down Expand Up @@ -100,7 +110,7 @@ Then start the training run:
nequix_train configs/nequix-mp-1.yml
```

This will take less than 125 hours on a single 4 x A100 node (<25 hours using the torch + kernels backend). The `batch_size` in the
This will take less than 125 hours on a single 4 x A100 node (<25 hours with kernels). The `batch_size` in the
config is per-device, so you should be able to run this on any number of GPUs
(although hyperparameters like learning rate are often sensitive to global batch
size, so keep that in mind).
Expand Down
1 change: 1 addition & 0 deletions configs/nequix-mp-1-pft-no-cotrain.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ hessian_weight: 100.0
val_every: 2
log_every: 100
ema_decay: 0.999
kernel: true
1 change: 1 addition & 0 deletions configs/nequix-mp-1-pft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ hessian_weight: 100.0
val_every: 2
log_every: 100
ema_decay: 0.999
kernel: true
1 change: 1 addition & 0 deletions configs/nequix-oam-1-pft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ hessian_weight: 100.0
val_every: 5
log_every: 100
ema_decay: 0.999
kernel: true
10 changes: 6 additions & 4 deletions nequix/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
model_path: str = None,
capacity_multiplier: float = 1.1, # Only for jax backend
backend: str = "jax",
use_kernel: bool = True, # Only for torch backend
use_kernel: bool = True,
use_compile: bool = True, # Only for torch backend
**kwargs,
):
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
path_backend = "jax" if model_path.suffix == ".nqx" else "torch"
if path_backend == backend:
if backend == "jax":
self.model, self.config = load_model_jax(model_path)
self.model, self.config = load_model_jax(model_path, use_kernel)
else:
from nequix.torch.model import load_model as load_model_torch

Expand All @@ -87,13 +87,15 @@ def __init__(

torch_model, torch_config = load_model_torch(model_path, use_kernel)
print("Converting PyTorch model to JAX ...")
self.model, self.config = convert_model_torch_to_jax(torch_model, torch_config)
self.model, self.config = convert_model_torch_to_jax(
torch_model, torch_config, use_kernel
)
out_path = model_path.parent / f"{model_name}.nqx"
save_model_jax(out_path, self.model, self.config)
else:
from nequix.torch.utils import convert_model_jax_to_torch

jax_model, jax_config = load_model_jax(model_path)
jax_model, jax_config = load_model_jax(model_path, use_kernel)
print("Converting JAX model to PyTorch ...")
self.model, self.config = convert_model_jax_to_torch(
jax_model, jax_config, use_kernel
Expand Down
110 changes: 101 additions & 9 deletions nequix/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import math
from typing import Callable, Optional, Sequence
import os
from typing import Any, Callable, Optional, Sequence

import e3nn_jax as e3nn
import equinox as eqx
Expand All @@ -10,6 +11,22 @@

from nequix.layer_norm import RMSLayerNorm

try:
    import torch  # noqa: F401
except ImportError:
    # torch is absent: set OEQ_NOTORCH so openequivariance can be imported
    # without its torch integration (needed for the JAX backend). This is
    # done only when torch is NOT installed — setting it unconditionally
    # would disable the torch backend for users who have both torch and jax.
    os.environ["OEQ_NOTORCH"] = "1"

try:
    # openequivariance_extjax provides the JAX bindings for the OEQ kernels;
    # both packages must import cleanly for kernel=True to be usable.
    import openequivariance as oeq
    import openequivariance_extjax  # noqa: F401

    OEQ_AVAILABLE = True
except ImportError:
    # kernels unavailable; NequixConvolution(kernel=True) will raise later.
    OEQ_AVAILABLE = False


def bessel_basis(x: jax.Array, num_basis: int, r_max: float) -> jax.Array:
prefactor = 2.0 / r_max
Expand All @@ -32,6 +49,23 @@ def polynomial_cutoff(x: jax.Array, r_max: float, p: float) -> jax.Array:
return out * jnp.where(x < 1.0, 1.0, 0.0)


class Sort(eqx.Module):
    """Reorders the channels of a flat feature array into sorted-irreps order.

    Precomputes, from an unsorted ``Irreps``, the per-irrep slices arranged in
    the order produced by ``Irreps.sort()``, so ``__call__`` can permute a
    concatenated feature vector to match the sorted layout.
    """

    irreps: e3nn.Irreps = eqx.field(static=True)
    irreps_sorted: e3nn.Irreps = eqx.field(static=True)
    slices_sorted: list = eqx.field(static=True)

    def __init__(self, irreps: e3nn.Irreps):
        self.irreps = irreps
        sorted_irreps, _, inverse = irreps.sort()
        self.irreps_sorted = sorted_irreps
        # slice of each original irrep, reordered by the inverse permutation
        # so that index k gives the source slice of the k-th sorted irrep
        original_slices = list(irreps.slices())
        self.slices_sorted = [original_slices[idx] for idx in inverse]

    def __call__(self, x: jax.Array) -> jax.Array:
        # gather each irrep's channels in sorted order, then stitch together
        pieces = [x[..., sl] for sl in self.slices_sorted]
        return jnp.concatenate(pieces, axis=-1)


class Linear(eqx.Module):
weights: jax.Array
bias: Optional[jax.Array]
Expand Down Expand Up @@ -96,14 +130,18 @@ def __call__(self, x: jax.Array) -> jax.Array:

class NequixConvolution(eqx.Module):
output_irreps: e3nn.Irreps = eqx.field(static=True)
tp_irreps: e3nn.Irreps = eqx.field(static=True)
index_weights: bool = eqx.field(static=True)
avg_n_neighbors: float = eqx.field(static=True)
kernel: bool = eqx.field(static=True)
tp_conv: Optional[Any] = eqx.field(static=True)

radial_mlp: MLP
linear_1: e3nn.equinox.Linear
linear_2: e3nn.equinox.Linear
skip: e3nn.equinox.Linear
layer_norm: Optional[RMSLayerNorm]
sort: Sort

def __init__(
self,
Expand All @@ -119,12 +157,50 @@ def __init__(
avg_n_neighbors: float,
index_weights: bool = True,
layer_norm: bool = False,
kernel: bool = False,
):
self.output_irreps = output_irreps
self.avg_n_neighbors = avg_n_neighbors
self.index_weights = index_weights
self.kernel = kernel

irreps_out_tp = []
instructions = []
for i, (mul, ir_in1) in enumerate(input_irreps):
for j, (_, ir_in2) in enumerate(sh_irreps):
for ir_out in ir_in1 * ir_in2:
if ir_out in output_irreps:
k = len(irreps_out_tp)
irreps_out_tp.append((mul, ir_out))
instructions.append((i, j, k, "uvu", True))

tp_irreps = e3nn.Irreps(irreps_out_tp)
_, _, inv = tp_irreps.sort()
self.tp_irreps = tp_irreps

if kernel:
instructions = [instructions[i] for i in inv]
if not OEQ_AVAILABLE:
raise ImportError(
"OpenEquivariance with JAX support is required for kernel=True. "
"Install both packages:\n"
" uv pip install 'openequivariance[jax]'\n"
" uv pip install 'openequivariance_extjax' --no-build-isolation"
)
problem = oeq.TPProblem(
str(input_irreps),
str(sh_irreps),
str(tp_irreps),
instructions=instructions,
shared_weights=False,
internal_weights=False,
)
self.tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
else:
self.tp_conv = None

tp_irreps = e3nn.tensor_product(input_irreps, sh_irreps, filter_ir_out=output_irreps)
self.sort = Sort(tp_irreps)
tp_irreps = self.sort.irreps_sorted

k1, k2, k3, k4 = jax.random.split(key, 4)

Expand Down Expand Up @@ -182,14 +258,27 @@ def __call__(
senders: jax.Array,
receivers: jax.Array,
) -> e3nn.IrrepsArray:
messages = self.linear_1(features)[senders]
messages = e3nn.tensor_product(messages, sh, filter_ir_out=self.output_irreps)
messages = self.linear_1(features)
radial_message = jax.vmap(self.radial_mlp)(radial_basis)
messages = messages * radial_message

messages_agg = e3nn.scatter_sum(
messages, dst=receivers, output_size=features.shape[0]
) / jnp.sqrt(jax.lax.stop_gradient(self.avg_n_neighbors))
if self.kernel:
messages_agg = self.sort(
self.tp_conv.forward(
messages.array,
sh.array,
radial_message,
receivers.astype(jnp.int32),
senders.astype(jnp.int32),
)
)
messages_agg = e3nn.IrrepsArray(self.sort.irreps_sorted, messages_agg)
else:
messages = messages[senders]
messages = e3nn.tensor_product(messages, sh, filter_ir_out=self.tp_irreps)
messages = messages * radial_message
messages_agg = e3nn.scatter_sum(messages, dst=receivers, output_size=features.shape[0])

messages_agg = messages_agg / jnp.sqrt(jax.lax.stop_gradient(self.avg_n_neighbors))

skip = self.skip(species, features) if self.index_weights else self.skip(features)
features = self.linear_2(messages_agg) + skip
Expand Down Expand Up @@ -237,6 +326,7 @@ def __init__(
avg_n_neighbors: float = 1.0,
atom_energies: Optional[Sequence[float]] = None,
layer_norm: bool = False,
kernel: bool = False,
):
self.lmax = lmax
self.cutoff = cutoff
Expand Down Expand Up @@ -271,6 +361,7 @@ def __init__(
avg_n_neighbors=avg_n_neighbors,
index_weights=index_weights,
layer_norm=layer_norm,
kernel=kernel,
)
)

Expand Down Expand Up @@ -446,7 +537,7 @@ def save_model(path: str, model: eqx.Module, config: dict):
eqx.tree_serialise_leaves(f, model)


def load_model(path: str) -> tuple[Nequix, dict]:
def load_model(path: str, kernel: bool = False) -> tuple[Nequix, dict]:
"""Load a model and its config from a file."""
with open(path, "rb") as f:
config = json.loads(f.readline().decode())
Expand All @@ -467,6 +558,7 @@ def load_model(path: str) -> tuple[Nequix, dict]:
shift=config["shift"],
scale=config["scale"],
avg_n_neighbors=config["avg_n_neighbors"],
kernel=kernel,
# NOTE: atom_energies will be in model weights
)
model = eqx.tree_deserialise_leaves(f, model)
Expand Down
2 changes: 1 addition & 1 deletion nequix/pft/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def train(config_path):
with open(config_path, "r") as f:
config = yaml.safe_load(f)

model, original_config = load_model(config["finetune_from"])
model, original_config = load_model(config["finetune_from"], config["kernel"])

if config["optimizer"] == "muon":
optim = optax.chain(
Expand Down
3 changes: 2 additions & 1 deletion nequix/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def convert_layer_torch_to_jax(layer_idx, torch_model, jax_model):
return jax_model


def convert_model_torch_to_jax(torch_model, config):
def convert_model_torch_to_jax(torch_model, config, use_kernel):
jax_model = Nequix(
key=jax.random.key(0),
n_species=len(config["atomic_numbers"]),
Expand All @@ -150,6 +150,7 @@ def convert_model_torch_to_jax(torch_model, config):
scale=config["scale"],
avg_n_neighbors=config["avg_n_neighbors"],
atom_energies=[config["atom_energies"][str(n)] for n in config["atomic_numbers"]],
kernel=use_kernel,
)
for layer_idx in range(len(torch_model.layers)):
jax_model = convert_layer_torch_to_jax(layer_idx, torch_model, jax_model)
Expand Down
1 change: 1 addition & 0 deletions nequix/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def train(config_path: str):
scale=stats["scale"],
avg_n_neighbors=stats["avg_n_neighbors"],
atom_energies=atom_energies,
kernel=config["kernel"],
)
if "finetune_from" in config and Path(config["finetune_from"]).exists():
if "atom_energies" in config:
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,16 @@ jax-md = { git = "https://github.com/jax-md/jax-md.git" }
[project.optional-dependencies]
torch = [
"e3nn>=0.5.8",
"openequivariance==0.4.1",
"openequivariance>=0.5.4",
"torch==2.7.0",
"torch-geometric>=2.6.1",
"setuptools",
]
oeq = [
# You need to install openequivariance_extjax separately
# uv pip install openequivariance_extjax --no-build-isolation
"openequivariance[jax]>=0.5.4",
]
pft = [
"phonopy>=2.43.1",
]
4 changes: 3 additions & 1 deletion scripts/convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def main():
save_model_torch(output_path, torch_model, torch_config)
elif input_backend == "torch" and output_backend == "jax":
torch_model, torch_config = load_model_torch(input_path)
jax_model, jax_config = convert_model_torch_to_jax(torch_model, torch_config)
jax_model, jax_config = convert_model_torch_to_jax(
torch_model, torch_config, use_kernel=False
)
save_model_jax(output_path, jax_model, jax_config)
else:
raise ValueError(f"invalid input and output backends: {input_backend} and {output_backend}")
Expand Down
Loading