-
Notifications
You must be signed in to change notification settings - Fork 54
Added the trexio support #393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2c6c089
f254e8b
32a10a8
7580944
0a68e27
6d6918a
57af5c7
80c4bee
7227e85
d064d32
067e26d
a4d7e15
e51f55f
6ec12e1
f64257e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| # IODATA is an input and output module for quantum chemistry. | ||
| # Copyright (C) 2011-2019 The IODATA Development Team | ||
| # | ||
| # This file is part of IODATA. | ||
| # | ||
| # IODATA is free software; you can redistribute it and/or | ||
| # modify it under the terms of the GNU General Public License | ||
| # as published by the Free Software Foundation; either version 3 | ||
| # of the License, or (at your option) any later version. | ||
| # | ||
| # IODATA is distributed in the hope that it will be useful, | ||
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
| # GNU General Public License for more details. | ||
| # | ||
| # You should have received a copy of the GNU General Public License | ||
| # along with this program; if not, see <http://www.gnu.org/licenses/> | ||
| # -- | ||
| """TrexIO file format.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| from typing import TextIO | ||
|
|
||
| import numpy as np | ||
|
|
||
| from ..docstrings import document_dump_one, document_load_one | ||
| from ..iodata import IOData | ||
| from ..utils import LineIterator, LoadError | ||
|
|
||
| __all__ = () | ||
|
|
||
| PATTERNS = ["*.trexio"] | ||
|
|
||
|
|
||
| def _import_trexio(): | ||
| """Lazily import the trexio module.""" | ||
| try: | ||
| import trexio # noqa: PLC0415 | ||
| except ImportError: | ||
| return None | ||
| return trexio | ||
|
|
||
|
|
||
| @document_load_one( | ||
| "TREXIO", | ||
| ["atcoords", "atnums"], | ||
| ["charge", "nelec", "spinpol"], | ||
| ) | ||
| def load_one(lit: LineIterator) -> dict: | ||
| """Do not edit this docstring. It will be overwritten.""" | ||
| trexio = _import_trexio() | ||
| filename = lit.filename | ||
|
|
||
| if trexio is None: | ||
| raise LoadError( | ||
| "Reading TREXIO files requires the 'trexio' Python package.", | ||
| filename, | ||
| ) | ||
|
|
||
| try: | ||
| # TrexIO needs to open the file itself. | ||
| # We cannot close lit.f because LineIterator might use it or context manager needs it. | ||
| # On Unix, opening same file twice for read is usually OK. | ||
| with trexio.File(filename, "r", back_end=trexio.TREXIO_HDF5) as tfile: | ||
| n_nuc = trexio.read_nucleus_num(tfile) | ||
| charges = np.asarray(trexio.read_nucleus_charge(tfile), dtype=float) | ||
| coords = np.asarray(trexio.read_nucleus_coord(tfile), dtype=float) | ||
|
|
||
| try: | ||
| nelec = int(trexio.read_electron_num(tfile)) | ||
| except trexio.Error: | ||
| nelec = None | ||
|
|
||
| try: | ||
| n_up = int(trexio.read_electron_up_num(tfile)) | ||
| n_dn = int(trexio.read_electron_dn_num(tfile)) | ||
| spinpol = n_up - n_dn | ||
| except trexio.Error: | ||
| spinpol = None | ||
|
|
||
| except LoadError: | ||
| raise | ||
| except Exception as exc: | ||
| raise LoadError(f"Failed to read TREXIO file: {exc}", filename) from exc | ||
|
|
||
| # Validate data consistency after reading | ||
| if charges.shape[0] != n_nuc or coords.shape[0] != n_nuc: | ||
| raise LoadError( | ||
| "Inconsistent nucleus.* fields in TREXIO file.", | ||
| filename, | ||
| ) | ||
|
|
||
| atnums = np.rint(charges).astype(int) | ||
|
|
||
| result: dict = { | ||
| "atcoords": coords, | ||
| "atnums": atnums, | ||
| } | ||
|
|
||
| if nelec is not None: | ||
| result["nelec"] = nelec | ||
| result["charge"] = float(charges.sum() - nelec) | ||
| if spinpol is not None: | ||
| result["spinpol"] = spinpol | ||
|
|
||
| return result | ||
|
|
||
|
|
||
| @document_dump_one( | ||
| "TREXIO", | ||
| ["atcoords", "atnums"], | ||
| ["charge", "nelec", "spinpol"], | ||
| ) | ||
| def dump_one(f: TextIO, data: IOData): | ||
| """Do not edit this docstring. It will be overwritten.""" | ||
| trexio = _import_trexio() | ||
| if trexio is None: | ||
| raise RuntimeError("Writing TREXIO files requires the 'trexio' Python package.") | ||
|
|
||
| if data.atcoords is None or data.atnums is None: | ||
| raise RuntimeError("TREXIO writer needs atcoords and atnums.") | ||
| if data.atcoords.shape[0] != data.atnums.shape[0]: | ||
| raise RuntimeError("Inconsistent number of atoms in atcoords and atnums.") | ||
|
|
||
| try: | ||
| filename = f.name | ||
| except AttributeError as exc: | ||
| raise RuntimeError( | ||
| "TREXIO writer expects a real file object with a .name attribute." | ||
| ) from exc | ||
|
|
||
| atcoords = np.asarray(data.atcoords, dtype=float) | ||
| atnums = np.asarray(data.atnums, dtype=float) | ||
| nelec = int(data.nelec) if data.nelec is not None else None | ||
| spinpol = int(data.spinpol) if data.spinpol is not None else None | ||
|
|
||
| # TrexIO needs to open the file itself. We close the file handle provided by api.py | ||
| # to avoid conflicts (e.g. file locking). api.py will harmlessly close it again. | ||
| f.close() | ||
| if os.path.exists(filename): | ||
| os.remove(filename) | ||
|
|
||
| with trexio.File(filename, "w", back_end=trexio.TREXIO_HDF5) as tfile: | ||
| trexio.write_nucleus_num(tfile, len(atnums)) | ||
| trexio.write_nucleus_charge(tfile, atnums.astype(float)) | ||
| trexio.write_nucleus_coord(tfile, atcoords) | ||
|
|
||
| if nelec is not None: | ||
| trexio.write_electron_num(tfile, nelec) | ||
| if spinpol is not None: | ||
| # Check for consistency between nelec and spinpol | ||
| if abs((nelec + spinpol) % 2) > 1.0e-8: | ||
| raise ValueError( | ||
| f"Inconsistent nelec ({nelec}) and spinpol ({spinpol}). " | ||
| "Sum and difference must be even numbers." | ||
| ) | ||
| n_up = (nelec + spinpol) // 2 | ||
| n_dn = (nelec - spinpol) // 2 | ||
| trexio.write_electron_up_num(tfile, n_up) | ||
| trexio.write_electron_dn_num(tfile, n_dn) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| # IODATA is an input and output module for quantum chemistry. | ||
| # Copyright (C) 2011-2019 The IODATA Development Team | ||
| # | ||
| # This file is part of IODATA. | ||
| # | ||
| # IODATA is free software; you can redistribute it and/or | ||
| # modify it under the terms of the GNU General Public License | ||
| # as published by the Free Software Foundation; either version 3 | ||
| # of the License, or (at your option) any later version. | ||
| # | ||
| # IODATA is distributed in the hope that it will be useful, | ||
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
| # GNU General Public License for more details. | ||
| # | ||
| # You should have received a copy of the GNU General Public License | ||
| # along with this program; if not, see <http://www.gnu.org/licenses/> | ||
| # -- | ||
| import importlib.util | ||
| import os | ||
| import subprocess | ||
| import sys | ||
|
|
||
| import pytest | ||
|
|
||
|
|
||
| @pytest.mark.skipif(sys.platform.startswith("win"), reason="TrexIO issues on Windows") | ||
| def test_load_dump_consistency(tmp_path): | ||
| """Check if dumping and loading a TREXIO file results in the same data. | ||
|
|
||
| Runs in a subprocess to avoid segmentation faults caused by conflict | ||
| between pytest execution model and trexio C-extension. | ||
| """ | ||
| # Skip tests if trexio is not installed, but do NOT import it here to avoid segfaults | ||
| if importlib.util.find_spec("trexio") is None: | ||
| pytest.skip("trexio not installed") | ||
| script = """ | ||
| import numpy as np | ||
| import os | ||
| import sys | ||
|
|
||
| from iodata import IOData | ||
| from iodata.api import load_one, dump_one | ||
|
|
||
| atcoords = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) | ||
| atnums = np.array([1, 1]) | ||
| nelec = 2.0 | ||
| spinpol = 0 | ||
| iodata_orig = IOData(atcoords=atcoords, atnums=atnums, nelec=nelec, spinpol=spinpol) | ||
|
|
||
| filename = "test.trexio" | ||
| if os.path.exists(filename): | ||
| os.remove(filename) | ||
|
|
||
| print(f"Dumping to {filename}") | ||
| dump_one(iodata_orig, filename, fmt="trexio") | ||
|
|
||
| print(f"Loading from {filename}") | ||
| iodata_new = load_one(filename, fmt="trexio") | ||
|
|
||
| print("Verifying data...") | ||
| np.testing.assert_allclose(iodata_new.atcoords, atcoords, err_msg="atcoords mismatch") | ||
| np.testing.assert_equal(iodata_new.atnums, atnums, err_msg="atnums mismatch") | ||
| np.testing.assert_allclose( | ||
| float(iodata_new.nelec), | ||
| nelec, | ||
| rtol=1.0e-8, | ||
| atol=1.0e-12, | ||
| err_msg=f"nelec mismatch: {iodata_new.nelec} != {nelec}", | ||
| ) | ||
| np.testing.assert_allclose( | ||
| float(iodata_new.charge), | ||
| 0.0, | ||
| rtol=1.0e-8, | ||
| atol=1.0e-12, | ||
| err_msg=f"charge mismatch: {iodata_new.charge} != 0.0", | ||
| ) | ||
| assert int(iodata_new.spinpol) == spinpol, ( | ||
| f"spinpol mismatch: {iodata_new.spinpol} != {spinpol}" | ||
| ) | ||
|
|
||
| print("Verification passed") | ||
| """ | ||
| script_file = tmp_path / "verify_trexio_subprocess.py" | ||
| script_file.write_text(script, encoding="utf-8") | ||
|
|
||
| # Determine project root (assuming this test is in iodata/test/) | ||
| current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
| project_root = os.path.abspath(os.path.join(current_dir, "../..")) | ||
|
|
||
| # Add project root to PYTHONPATH to ensure local iodata code is used | ||
| env = os.environ.copy() | ||
| current_pythonpath = env.get("PYTHONPATH", "") | ||
| env["PYTHONPATH"] = f"{project_root}:{current_pythonpath}" | ||
|
|
||
| subprocess.check_call([sys.executable, str(script_file)], cwd=tmp_path, env=env) | ||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -53,6 +53,7 @@ dev = [ | |||
| "sphinx_autodoc_typehints", | ||||
| "sphinx-copybutton", | ||||
| "sympy", | ||||
|
|
||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please avoid unintended changes due to edits that were later undone. (E.g. by checking the diffs.) In order to test the trexio code, it would be useful to include a set of extra = [
"trexio",
]With this, you can update the following line in the CI workflow: iodata/.github/workflows/pytest.yaml Line 48 in fa4d09d
E.g. change it to: |
||||
| ] | ||||
|
|
||||
| [project.urls] | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This construction with the subprocess should be avoided. You can make the test conditional on the availability of the trexio package without a subprocess. Something along the following lines: