Skip to content

Commit 29a34f0

Browse files
committed
Use chemfiles structures as input
1 parent be2baf7 commit 29a34f0

6 files changed

Lines changed: 502 additions & 11 deletions

File tree

python/chemiscope/structures/__init__.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
import warnings
3+
from typing import Sequence
34

45
from ._ase import ( # noqa: F401
56
_ase_all_atomic_environments,
@@ -9,6 +10,12 @@
910
ase_tensors_to_ellipsoids,
1011
ase_vectors_to_arrows,
1112
)
13+
from ._chemfiles import ( # noqa: F401
14+
_chemfiles_all_atomic_environments,
15+
_chemfiles_extract_properties,
16+
_chemfiles_to_json,
17+
_chemfiles_valid_structures,
18+
)
1219
from ._mda import ( # noqa: F401
1320
_mda_all_atomic_environments,
1421
_mda_extract_properties,
@@ -34,7 +41,7 @@ def _chemiscope_valid_structures(structures):
3441
:return: tuple (structures as list, boolean indicating if structures are valid)
3542
"""
3643

37-
if not hasattr(structures, "__iter__"):
44+
if not isinstance(structures, Sequence):
3845
return structures, False
3946

4047
first_structure = structures[0]
@@ -75,15 +82,22 @@ def _guess_adapter(structures):
7582
if use_ase:
7683
return ase_structures, "ASE"
7784

78-
stk_structures, use_stk = _stk_valid_structures(structures)
79-
if use_stk:
80-
return stk_structures, "stk"
85+
chemfiles_structures, use_chemfiles = _chemfiles_valid_structures(structures)
86+
if use_chemfiles:
87+
return chemfiles_structures, "chemfiles"
8188

8289
mda_structures, use_mda = _mda_valid_structures(structures)
8390
if use_mda:
8491
return mda_structures, "mda"
8592

86-
raise Exception(f"unknown structure type: '{structures[0].__class__.__name__}'")
93+
stk_structures, use_stk = _stk_valid_structures(structures)
94+
if use_stk:
95+
return stk_structures, "stk"
96+
97+
if isinstance(structures, Sequence):
98+
raise Exception(f"unknown structure type: '{structures[0].__class__.__name__}'")
99+
else:
100+
raise Exception(f"unknown structure type: '{structures.__class__.__name__}'")
87101

88102

89103
def structures_to_json(structures):
@@ -103,12 +117,14 @@ def structures_to_json(structures):
103117
json_data = structures
104118
elif adapter == "ASE":
105119
json_data = [_ase_to_json(s) for s in structures]
106-
elif adapter == "stk":
107-
json_data = [_stk_to_json(s) for s in structures]
120+
elif adapter == "chemfiles":
121+
json_data = [_chemfiles_to_json(s) for s in structures]
108122
elif adapter == "mda":
109123
# Be careful of the lazy loading of `structures.atoms`, which is updated during
110124
# the iteration of the trajectory
111125
json_data = [_mda_to_json(structures) for _ in structures.universe.trajectory]
126+
elif adapter == "stk":
127+
json_data = [_stk_to_json(s) for s in structures]
112128
else:
113129
raise Exception("reached unreachable code")
114130

@@ -139,10 +155,10 @@ def extract_properties(structures=None, only=None, *, environments=None, frames=
139155

140156
if adapter == "ASE":
141157
return _ase_extract_properties(structures, only, environments)
142-
158+
elif adapter == "chemfiles":
159+
return _chemfiles_extract_properties(structures, only, environments)
143160
elif adapter == "mda":
144161
return _mda_extract_properties(structures, only, environments)
145-
146162
elif adapter == "stk":
147163
raise RuntimeError(
148164
"stk structures do not contain properties, you must manually provide them"
@@ -178,9 +194,11 @@ def all_atomic_environments(structures=None, cutoff=3.5, *, frames=None):
178194

179195
if adapter == "ASE":
180196
return _ase_all_atomic_environments(structures, cutoff)
181-
elif adapter == "stk":
182-
return _stk_all_atomic_environments(structures, cutoff)
197+
elif adapter == "chemfiles":
198+
return _chemfiles_all_atomic_environments(structures, cutoff)
183199
elif adapter == "mda":
184200
return _mda_all_atomic_environments(structures, cutoff)
201+
elif adapter == "stk":
202+
return _stk_all_atomic_environments(structures, cutoff)
185203
else:
186204
raise Exception("reached unreachable code")

python/chemiscope/structures/_ase.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def _ase_valid_structures(structures):
1919
except TypeError:
2020
return [], False
2121

22+
if len(structures_list) == 0:
23+
return [], False
24+
2225
if HAVE_ASE and isinstance(structures_list[0], ase.Atoms):
2326
for structure in structures_list:
2427
assert isinstance(structure, ase.Atoms)
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import warnings
2+
from typing import Sequence
3+
4+
import numpy as np
5+
6+
from ._ase import _remove_invalid_properties
7+
8+
9+
# chemfiles is an optional dependency: HAVE_CHEMFILES records whether it could
# be imported, and the chemfiles names below only exist when it is True.
try:
    import chemfiles

    # Compare the numeric (major, minor) pair instead of the raw strings.
    # String comparison is lexicographic, which let "0.11.0" itself through
    # the old upper bound and flagged a plain "0.10" as too old.
    try:
        _chemfiles_major_minor = tuple(
            int(part) for part in chemfiles.__version__.split(".")[:2]
        )
    except ValueError:
        # unparsable version string: report it as unsupported
        _chemfiles_major_minor = None

    if _chemfiles_major_minor != (0, 10):
        print(
            "chemiscope requires chemfiles version >=0.10,<0.11; "
            f"but version {chemfiles.__version__} is installed."
        )

    from chemfiles import Frame, MemoryTrajectory, Trajectory

    HAVE_CHEMFILES = True
except ImportError:
    HAVE_CHEMFILES = False
23+
24+
25+
def _chemfiles_valid_structures(structures):
    """Check whether ``structures`` can be handled by the chemfiles adapter.

    :param structures: a chemfiles ``Trajectory``/``MemoryTrajectory``, a single
        ``Frame``, or a sequence of ``Frame``
    :return: tuple ``(structures, valid)``. When ``valid`` is ``True`` the
        first item is a sequence of frames (a trajectory is read step by step,
        a lone frame is wrapped in a list); otherwise the input is returned
        unchanged.
    """
    if not HAVE_CHEMFILES:
        return structures, False

    if isinstance(structures, (Trajectory, MemoryTrajectory)):
        # read every step of the trajectory into a list of frames
        frames = []
        for step in range(structures.nsteps):
            frames.append(structures.read_step(step))
        return frames, True

    if isinstance(structures, Frame):
        return [structures], True

    is_frame_sequence = (
        isinstance(structures, Sequence)
        and len(structures) > 0
        and isinstance(structures[0], Frame)
    )
    if not is_frame_sequence:
        return structures, False

    # the first entry is a Frame, so every other entry must be one as well
    for candidate in structures:
        assert isinstance(candidate, Frame)
    return structures, True
46+
47+
48+
# fmt: off
49+
# Element symbols indexed by atomic number, one source line per stretch of a
# period; index 0 holds the placeholder "X".
ELEMENTS = (
    "X "
    "H He "
    "Li Be B C N O F Ne "
    "Na Mg Al Si P S Cl Ar "
    "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr "
    "Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe "
    "Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu "
    "Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn "
    "Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr "
    "Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og"
).split()
60+
# fmt: on
61+
62+
63+
def _chemfiles_to_json(frame):
    """Implementation of structures_to_json for chemfiles' ``Frame``.

    :param frame: the chemfiles ``Frame`` to convert
    :return: dictionary always containing ``size``, ``names``, ``x``, ``y`` and
        ``z``. ``elements`` is added when every atom has a non-zero atomic
        number, ``cell`` when the cell is not infinite, ``bonds`` when the
        topology has bonds, and ``chains``/``resnames``/``resids``/``hetatom``
        when every atom belongs to a residue.
    """

    # Built inside the function: ``chemfiles`` is an optional import, so its
    # enum can not be referenced while loading this module without it.
    BOND_ORDERS_TO_NUMERIC = {
        chemfiles.BondOrder.Unknown: 1,
        chemfiles.BondOrder.Single: 1,
        chemfiles.BondOrder.Double: 2,
        chemfiles.BondOrder.Triple: 3,
        # 3Dmol seems to cap bond order at 3
        chemfiles.BondOrder.Quadruple: 3,
        chemfiles.BondOrder.Quintuplet: 3,
        chemfiles.BondOrder.Amide: 1,
        chemfiles.BondOrder.Aromatic: 4,
    }

    size = len(frame.atoms)
    data = {
        "size": size,
        "names": [atom.name for atom in frame.atoms],
    }

    # an atomic number of 0 means the element is unknown; only emit
    # ``elements`` when it can be given for every single atom
    atomic_numbers = [atom.atomic_number for atom in frame.atoms]
    if all(number != 0 for number in atomic_numbers):
        data["elements"] = [ELEMENTS[number] for number in atomic_numbers]

    positions = frame.positions
    data["x"] = [float(position[0]) for position in positions]
    data["y"] = [float(position[1]) for position in positions]
    data["z"] = [float(position[2]) for position in positions]

    if frame.cell.shape != chemfiles.CellShape.Infinite:
        data["cell"] = frame.cell.matrix.T.flatten().tolist()

    # bonds
    topology = frame.topology
    if len(topology.bonds) > 0:
        data["bonds"] = [
            (int(bond[0]), int(bond[1]), BOND_ORDERS_TO_NUMERIC[order])
            for bond, order in zip(topology.bonds, topology.bonds_orders, strict=True)
        ]

    # biomolecule-specific information
    residues = [topology.residue_for_atom(atom_i) for atom_i in range(size)]
    atoms_in_residue = sum(residue is not None for residue in residues)

    if atoms_in_residue == size and size > 0:
        chains = []
        resnames = []
        resids = []
        hetatom = []
        for residue in residues:
            resids.append(residue.id)
            resnames.append(residue.name)

            residue_properties = residue.list_properties()
            if "chainname" in residue_properties:
                chains.append(residue["chainname"])
            else:
                chains.append("")

            if "is_standard_pdb" in residue_properties:
                hetatom.append(not residue["is_standard_pdb"])
            else:
                hetatom.append(True)

        data["chains"] = chains
        data["resnames"] = resnames
        data["resids"] = resids
        data["hetatom"] = hetatom
    elif atoms_in_residue > 0:
        # These are per-atom arrays: skipping the atoms without a residue
        # would make them shorter than ``size`` and shift every later entry
        # onto the wrong atom, so nothing is emitted instead.
        warnings.warn(
            "some atoms are not part of any residue, the residue information "
            "will be ignored",
            stacklevel=2,
        )

    return data
145+
146+
147+
def _chemfiles_all_atomic_environments(structures, cutoff):
148+
environments = []
149+
for frame_i, frame in enumerate(structures):
150+
for atom_i in range(len(frame.atoms)):
151+
environments.append((frame_i, atom_i, cutoff))
152+
153+
return environments
154+
155+
156+
def _chemfiles_get_structure_properties(frames):
157+
# extract the set of common properties between all frames
158+
all_properties = {}
159+
extra = set()
160+
161+
for name in frames[0].list_properties():
162+
all_properties[name] = [frames[0][name]]
163+
164+
for frame in frames[1:]:
165+
current_properties = frame.list_properties()
166+
for name in current_properties:
167+
if name in all_properties:
168+
all_properties[name].append(frame[name])
169+
else:
170+
extra.add(name)
171+
172+
for name in list(all_properties.keys()):
173+
if name not in current_properties:
174+
all_properties.pop(name, None)
175+
extra.add(name)
176+
177+
if len(extra) != 0:
178+
warnings.warn(
179+
"the following structure properties are only defined for a subset "
180+
f"of structures: {list(sorted(extra))}; they will be ignored",
181+
stacklevel=2,
182+
)
183+
184+
# ensures that if a property is a mix of strings and numbers, everything
185+
# is converted to string (as these should be categorical properties)
186+
for name in all_properties.keys():
187+
property = all_properties[name]
188+
if any(isinstance(x, str) for x in property):
189+
all_properties[name] = [str(x) for x in property]
190+
191+
return all_properties
192+
193+
194+
def _chemfiles_get_atom_properties(frames, environments):
195+
assert environments is not None
196+
# extract the set of common properties between all atoms in all frames
197+
all_properties = {}
198+
extra = set()
199+
200+
frame = frames[environments[0][0]]
201+
atom = frame.atoms[environments[0][1]]
202+
for name in atom.list_properties():
203+
all_properties[name] = [atom[name]]
204+
205+
for frame_i, atom_i, _ in environments[1:]:
206+
frame = frames[frame_i]
207+
atom = frame.atoms[atom_i]
208+
209+
current_properties = atom.list_properties()
210+
for name in current_properties:
211+
if name in all_properties:
212+
all_properties[name].append(atom[name])
213+
else:
214+
extra.add(name)
215+
216+
for name in list(all_properties.keys()):
217+
if name not in current_properties:
218+
all_properties.pop(name, None)
219+
extra.add(name)
220+
221+
if len(extra) != 0:
222+
warnings.warn(
223+
"the following atomic properties are only defined for a subset "
224+
f"of structures: {list(sorted(extra))}; they will be ignored",
225+
stacklevel=2,
226+
)
227+
228+
return all_properties
229+
230+
231+
def _chemfiles_extract_properties(structures, only=None, environments=None):
    """implementation of ``extract_properties`` for chemfiles

    :param structures: sequence of chemfiles frames
    :param only: optional iterable of property names to keep; names that are
        not available are silently skipped. ``None`` keeps everything.
    :param environments: optional sequence of ``(frame index, atom index,
        cutoff)`` tuples restricting the atoms used for atom properties.
        ``None`` uses every atom of every frame.
    :return: dictionary mapping property names to ``{"target": ..., "values":
        ...}``, with ``"structure"`` or ``"atom"`` as target. When a name
        exists for both targets, the structure property wins.
    """

    def _select(available):
        # keep everything, or only the requested names that do exist
        if only is None:
            return available
        return {name: available[name] for name in only if name in available}

    # create property in the format expected by create_input
    properties = {
        name: {"target": "structure", "values": np.asarray(value)}
        for name, value in _select(
            _chemfiles_get_structure_properties(structures)
        ).items()
    }
    _remove_invalid_properties(properties, "chemfiles")

    if environments is None:
        # the cutoff is not used when extracting properties
        environments = _chemfiles_all_atomic_environments(structures, cutoff=0.0)

    # create property in the format expected by create_input
    atom_properties = {
        name: {"target": "atom", "values": np.stack(value, axis=0)}
        for name, value in _select(
            _chemfiles_get_atom_properties(structures, environments)
        ).items()
    }
    # Validate the atom properties here: the structure ones were already
    # checked above, and passing them again left atom properties unchecked.
    _remove_invalid_properties(atom_properties, "chemfiles")

    for name, values in atom_properties.items():
        if name in properties:
            warnings.warn(
                f"a property named '{name}' is defined for both atoms and structures, "
                "the atom one will be ignored",
                stacklevel=2,
            )
        else:
            properties[name] = values

    return properties

python/chemiscope/structures/_stk.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def _stk_valid_structures(
2222
elif (
2323
HAVE_STK
2424
and isinstance(structures, list)
25+
and len(structures) > 0
2526
and isinstance(structures[0], Molecule)
2627
):
2728
for structure in structures:

0 commit comments

Comments
 (0)