Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions iodata/formats/trexio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""TrexIO file format."""

from __future__ import annotations

import os
from typing import TextIO

import numpy as np

from ..docstrings import document_dump_one, document_load_one
from ..iodata import IOData
from ..utils import LineIterator, LoadError

# The tuple of public names is empty, so ``from <this module> import *`` exports nothing.
__all__ = ()

# Filename glob(s) associated with the TREXIO format.
# NOTE(review): presumably consumed by format auto-detection elsewhere in IODATA — confirm.
PATTERNS = ["*.trexio"]


def _import_trexio():
"""Lazily import the trexio module."""
try:
import trexio # noqa: PLC0415
except ImportError:
return None
return trexio


@document_load_one(
    "TREXIO",
    ["atcoords", "atnums"],
    ["charge", "nelec", "spinpol"],
)
def load_one(lit: LineIterator) -> dict:
    """Do not edit this docstring. It will be overwritten."""
    # Summary (kept in comments because the docstring is replaced by the decorator):
    # read nucleus.* and, when available, electron.* from a TREXIO HDF5 file and
    # return them as a dict with keys "atcoords", "atnums" and optionally
    # "nelec", "charge" and "spinpol".  Any failure is reported as LoadError.
    trexio = _import_trexio()
    filename = lit.filename

    if trexio is None:
        raise LoadError(
            "Reading TREXIO files requires the 'trexio' Python package.",
            filename,
        )

    try:
        # TREXIO opens the file itself, by name.  The handle held by ``lit`` is
        # deliberately left open: it is not ours to close.
        with trexio.File(filename, "r", back_end=trexio.TREXIO_HDF5) as tfile:
            n_nuc = trexio.read_nucleus_num(tfile)
            charges = np.asarray(trexio.read_nucleus_charge(tfile), dtype=float)
            coords = np.asarray(trexio.read_nucleus_coord(tfile), dtype=float)

            # Electron counts are treated as optional: a trexio.Error while
            # reading one of them is taken to mean "not available".
            try:
                nelec = int(trexio.read_electron_num(tfile))
            except trexio.Error:
                nelec = None

            try:
                n_up = int(trexio.read_electron_up_num(tfile))
                n_dn = int(trexio.read_electron_dn_num(tfile))
            except trexio.Error:
                n_up = n_dn = None
    except LoadError:
        raise
    except Exception as exc:
        # Boundary handler: surface every backend failure as a LoadError that
        # carries the file name, keeping the original exception as the cause.
        raise LoadError(f"Failed to read TREXIO file: {exc}", filename) from exc

    # Validate data consistency after reading.
    if charges.shape[0] != n_nuc or coords.shape[0] != n_nuc:
        raise LoadError(
            "Inconsistent nucleus.* fields in TREXIO file.",
            filename,
        )

    if n_up is None:
        spinpol = None
    else:
        # NOTE(review): the sign of (up - dn) is kept as stored in the file;
        # confirm whether IOData expects a non-negative spin polarization.
        spinpol = n_up - n_dn
        if nelec is None:
            # The total electron count was not stored, but it is fully
            # determined by the two spin-resolved counts.
            nelec = n_up + n_dn

    result: dict = {
        "atcoords": coords,
        # Nuclear charges are stored as floats; round to the nearest integer.
        "atnums": np.rint(charges).astype(int),
    }

    if nelec is not None:
        result["nelec"] = nelec
        result["charge"] = float(charges.sum() - nelec)
    if spinpol is not None:
        result["spinpol"] = spinpol

    return result


@document_dump_one(
    "TREXIO",
    ["atcoords", "atnums"],
    ["charge", "nelec", "spinpol"],
)
def dump_one(f: TextIO, data: IOData):
    """Do not edit this docstring. It will be overwritten."""
    # Summary (kept in comments because the docstring is replaced by the decorator):
    # write nucleus.num/charge/coord and, when ``data.nelec`` is set,
    # electron.num (plus up_num/dn_num when ``data.spinpol`` is also set) to a
    # TREXIO HDF5 file located at ``f.name``.
    # Raises RuntimeError for a missing trexio package, missing or mismatched
    # geometry, or a file object without ``.name``; raises ValueError when
    # nelec and spinpol cannot be split into non-negative integer up/down counts.
    trexio = _import_trexio()
    if trexio is None:
        raise RuntimeError("Writing TREXIO files requires the 'trexio' Python package.")

    if data.atcoords is None or data.atnums is None:
        raise RuntimeError("TREXIO writer needs atcoords and atnums.")
    if data.atcoords.shape[0] != data.atnums.shape[0]:
        raise RuntimeError("Inconsistent number of atoms in atcoords and atnums.")

    try:
        filename = f.name
    except AttributeError as exc:
        raise RuntimeError(
            "TREXIO writer expects a real file object with a .name attribute."
        ) from exc

    atcoords = np.asarray(data.atcoords, dtype=float)
    charges = np.asarray(data.atnums, dtype=float)
    # Round instead of truncating so that floating-point noise (e.g. 1.9999999)
    # does not silently drop an electron.
    nelec = int(round(float(data.nelec))) if data.nelec is not None else None
    spinpol = int(round(float(data.spinpol))) if data.spinpol is not None else None

    # Derive and validate the spin-resolved counts BEFORE any destructive file
    # operation, so that invalid input cannot leave a half-written file behind.
    n_up = n_dn = None
    if nelec is not None and spinpol is not None:
        if (nelec + spinpol) % 2 != 0:
            raise ValueError(
                f"Inconsistent nelec ({nelec}) and spinpol ({spinpol}). "
                "Sum and difference must be even numbers."
            )
        n_up = (nelec + spinpol) // 2
        n_dn = (nelec - spinpol) // 2
        if n_up < 0 or n_dn < 0:
            raise ValueError(
                f"Inconsistent nelec ({nelec}) and spinpol ({spinpol}). "
                "The number of up and down electrons must be non-negative."
            )

    # TREXIO needs to open the file itself.  Close the handle we were given to
    # avoid conflicts (e.g. file locking) and remove the placeholder file so the
    # backend starts from a clean slate.  Closing an already-closed file object
    # again later is a no-op.
    f.close()
    if os.path.exists(filename):
        os.remove(filename)

    with trexio.File(filename, "w", back_end=trexio.TREXIO_HDF5) as tfile:
        trexio.write_nucleus_num(tfile, len(charges))
        trexio.write_nucleus_charge(tfile, charges)
        trexio.write_nucleus_coord(tfile, atcoords)

        if nelec is not None:
            trexio.write_electron_num(tfile, nelec)
            # Spin-resolved counts are only written when spinpol was provided.
            if n_up is not None:
                trexio.write_electron_up_num(tfile, n_up)
                trexio.write_electron_dn_num(tfile, n_dn)
96 changes: 96 additions & 0 deletions iodata/test/test_trexio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
import importlib.util
import os
import subprocess
import sys

import pytest


@pytest.mark.skipif(sys.platform.startswith("win"), reason="TrexIO issues on Windows")
def test_load_dump_consistency(tmp_path):
    """Check if dumping and loading a TREXIO file results in the same data.

    The round trip is executed by a helper script in a child interpreter, with
    ``tmp_path`` as working directory.  The test fails when the child exits
    with a non-zero status.
    """
    # Skip when trexio is not installed.  Only its import spec is looked up;
    # the package itself is not imported into the pytest process.
    if importlib.util.find_spec("trexio") is None:
        pytest.skip("trexio not installed")
    script = """
import numpy as np
import os
import sys

from iodata import IOData
from iodata.api import load_one, dump_one

atcoords = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
atnums = np.array([1, 1])
nelec = 2.0
spinpol = 0
iodata_orig = IOData(atcoords=atcoords, atnums=atnums, nelec=nelec, spinpol=spinpol)

filename = "test.trexio"
if os.path.exists(filename):
    os.remove(filename)

print(f"Dumping to {filename}")
dump_one(iodata_orig, filename, fmt="trexio")

print(f"Loading from {filename}")
iodata_new = load_one(filename, fmt="trexio")

print("Verifying data...")
np.testing.assert_allclose(iodata_new.atcoords, atcoords, err_msg="atcoords mismatch")
np.testing.assert_equal(iodata_new.atnums, atnums, err_msg="atnums mismatch")
np.testing.assert_allclose(
    float(iodata_new.nelec),
    nelec,
    rtol=1.0e-8,
    atol=1.0e-12,
    err_msg=f"nelec mismatch: {iodata_new.nelec} != {nelec}",
)
np.testing.assert_allclose(
    float(iodata_new.charge),
    0.0,
    rtol=1.0e-8,
    atol=1.0e-12,
    err_msg=f"charge mismatch: {iodata_new.charge} != 0.0",
)
assert int(iodata_new.spinpol) == spinpol, (
    f"spinpol mismatch: {iodata_new.spinpol} != {spinpol}"
)

print("Verification passed")
"""
    script_file = tmp_path / "verify_trexio_subprocess.py"
    script_file.write_text(script, encoding="utf-8")

    # Determine project root (assuming this test is in iodata/test/).
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.abspath(os.path.join(current_dir, "../.."))

    # Put the project root first on PYTHONPATH so the local iodata code is used.
    # Join with os.pathsep (not a hard-coded ":") and only append the inherited
    # value when it is non-empty, so no dangling empty entry is produced.
    env = os.environ.copy()
    search_path = [project_root]
    inherited = env.get("PYTHONPATH", "")
    if inherited:
        search_path.append(inherited)
    env["PYTHONPATH"] = os.pathsep.join(search_path)

    subprocess.check_call([sys.executable, str(script_file)], cwd=tmp_path, env=env)
Comment on lines +37 to +96
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This construction with the subprocess should be avoided. You can make the test conditional on the availability of the trexio package without a subprocess. Something along the following lines:

try:
    import trexio
except ImportError:
    trexio = None

@pytest.mark.skipif(trexio is None, reason="requires trexio")
def test_something():
    ...

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dev = [
"sphinx_autodoc_typehints",
"sphinx-copybutton",
"sympy",

Copy link
Member

@tovrstra tovrstra Dec 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please avoid unintended changes due to edits that were later undone. (E.g. by checking the diffs.)

In order to test the trexio code, it would be useful to include a set of extra optional dependencies like this:

extra = [
    "trexio",
]

With this, you can update the following line in the CI workflow:

run: pip install -e .[dev]

E.g. change it to:

        run: pip install -e .[dev,extra]

]

[project.urls]
Expand Down
Loading