|
| 1 | +import warnings |
| 2 | +from typing import Sequence |
| 3 | + |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from ._ase import _remove_invalid_properties |
| 7 | + |
| 8 | + |
def _chemfiles_version_is_supported(version):
    """
    Check whether ``version`` (a chemfiles version string) is in the ``0.10``
    series, i.e. satisfies ``>=0.10,<0.11``.

    Only the leading ``major.minor`` components are used, and they are compared
    as integers rather than as text: lexicographic string comparison only
    orders version numbers correctly by accident, and the previous
    ``> "0.11.0"`` test let ``0.11.0`` itself through.

    :param version: version string, e.g. ``"0.10.4"``
    :return: ``True`` if the version is ``0.10.x``, ``False`` otherwise,
        including when the string can not be parsed
    """
    components = version.split(".")
    try:
        major, minor = int(components[0]), int(components[1])
    except (IndexError, ValueError):
        # not something we can confirm as supported
        return False

    return (major, minor) == (0, 10)


try:
    import chemfiles

    if not _chemfiles_version_is_supported(chemfiles.__version__):
        # use the warnings machinery like the rest of this module, instead of
        # writing to stdout when the module is imported
        warnings.warn(
            "chemiscope requires chemfiles version >=0.10,<0.11; "
            f"but version {chemfiles.__version__} is installed.",
            stacklevel=2,
        )

    from chemfiles import Frame, MemoryTrajectory, Trajectory

    HAVE_CHEMFILES = True
except ImportError:
    # chemfiles is optional: the functions below check this flag
    HAVE_CHEMFILES = False
| 23 | + |
| 24 | + |
def _chemfiles_valid_structures(structures):
    """
    Check whether ``structures`` is chemfiles data and normalize it.

    Returns a ``(structures, is_chemfiles)`` tuple. When the input is a
    chemfiles ``Trajectory``/``MemoryTrajectory``, a single ``Frame``, or a
    non-empty sequence starting with a ``Frame``, the first element is a
    sequence of frames and the second is ``True``. In every other case
    (including chemfiles not being installed) the input is handed back
    unchanged together with ``False``.
    """
    if not HAVE_CHEMFILES:
        return structures, False

    if isinstance(structures, (Trajectory, MemoryTrajectory)):
        # read all the steps of the trajectory upfront
        frames = [structures.read_step(step) for step in range(structures.nsteps)]
        return frames, True

    if isinstance(structures, Frame):
        return [structures], True

    looks_like_frames = (
        isinstance(structures, Sequence)
        and len(structures) > 0
        and isinstance(structures[0], Frame)
    )
    if not looks_like_frames:
        return structures, False

    # the first entry is a Frame, all the other ones must be as well
    for candidate in structures:
        assert isinstance(candidate, Frame)
    return structures, True
| 46 | + |
| 47 | + |
# Chemical symbols, indexed by atomic number (`ELEMENTS[1]` is "H", the last
# entry is "Og"). Index 0 holds the placeholder "X"; `_chemfiles_to_json` never
# looks it up since it treats an atomic number of 0 as "no element".
# fmt: off
ELEMENTS = [
    "X", "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si",
    "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni",
    "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo",
    "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba",
    "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
    "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po",
    "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf",
    "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn",
    "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
]
# fmt: on
| 61 | + |
| 62 | + |
def _chemfiles_to_json(frame):
    """Implementation of structures_to_json for chemfiles' ``Frame``.

    The returned dictionary always contains ``size``, ``names``, ``x``, ``y``
    and ``z``. The other entries are only added when the frame carries the
    corresponding data:

    - ``elements``: only if every atom has a non-zero atomic number;
    - ``cell``: only if the cell shape is not ``Infinite``;
    - ``bonds``: only if the topology has at least one bond, as
      ``(i, j, order)`` tuples;
    - ``chains``, ``resnames``, ``resids``, ``hetatom``: only if every atom
      belongs to a residue. If only some atoms do, a warning is emitted and
      these entries are left out, since they would otherwise contain fewer
      than ``size`` values and no longer line up with the atoms.

    :param frame: chemfiles ``Frame`` to convert
    :return: dictionary describing the structure
    """

    BOND_ORDERS_TO_NUMERIC = {
        chemfiles.BondOrder.Unknown: 1,
        chemfiles.BondOrder.Single: 1,
        chemfiles.BondOrder.Double: 2,
        chemfiles.BondOrder.Triple: 3,
        # 3Dmol seems to cap bond order at 3
        chemfiles.BondOrder.Quadruple: 3,
        chemfiles.BondOrder.Quintuplet: 3,
        chemfiles.BondOrder.Amide: 1,
        chemfiles.BondOrder.Aromatic: 4,
    }

    atoms = frame.atoms
    size = len(atoms)

    data = {}
    data["size"] = size
    data["names"] = [atom.name for atom in atoms]

    # an atomic number of 0 means the atom has no known element: in this case
    # do not output elements at all instead of a partial list
    atomic_numbers = [atom.atomic_number for atom in atoms]
    if all(number != 0 for number in atomic_numbers):
        data["elements"] = [ELEMENTS[number] for number in atomic_numbers]

    positions = frame.positions
    data["x"] = [float(position[0]) for position in positions]
    data["y"] = [float(position[1]) for position in positions]
    data["z"] = [float(position[2]) for position in positions]

    if frame.cell.shape != chemfiles.CellShape.Infinite:
        data["cell"] = frame.cell.matrix.T.flatten().tolist()

    # the topology does not change below: get it once and use it for both the
    # bonds and the per-atom residue lookup
    topology = frame.topology

    # bonds
    bonds = topology.bonds
    if len(bonds) > 0:
        data["bonds"] = [
            (int(bond[0]), int(bond[1]), BOND_ORDERS_TO_NUMERIC[order])
            for bond, order in zip(bonds, topology.bonds_orders, strict=True)
        ]

    # biomolecule-specific information
    chains = []
    resnames = []
    resids = []
    hetatom = []
    atoms_without_residue = 0
    for atom_i in range(size):
        residue = topology.residue_for_atom(atom_i)
        if residue is None:
            atoms_without_residue += 1
            continue

        resids.append(residue.id)
        resnames.append(residue.name)

        residue_properties = residue.list_properties()
        if "chainname" in residue_properties:
            chains.append(residue["chainname"])
        else:
            chains.append("")

        if "is_standard_pdb" in residue_properties:
            hetatom.append(not residue["is_standard_pdb"])
        else:
            hetatom.append(True)

    if resids:
        if atoms_without_residue == 0:
            data["chains"] = chains
            data["resnames"] = resnames
            data["resids"] = resids
            data["hetatom"] = hetatom
        else:
            # the lists above contain one entry per atom *with* a residue,
            # they can not be used as per-atom data for this structure
            warnings.warn(
                f"{atoms_without_residue} of the {size} atoms in this structure "
                "do not belong to a residue; biomolecule information (chains, "
                "resnames, resids, hetatom) will be ignored",
                stacklevel=2,
            )

    return data
| 145 | + |
| 146 | + |
def _chemfiles_all_atomic_environments(structures, cutoff):
    """
    Create one ``(structure index, atom index, cutoff)`` environment for every
    atom of every frame in ``structures``, in frame then atom order.
    """
    return [
        (structure_index, atom_index, cutoff)
        for structure_index, structure in enumerate(structures)
        for atom_index in range(len(structure.atoms))
    ]
| 154 | + |
| 155 | + |
def _chemfiles_get_structure_properties(frames):
    """
    Collect the frame-level properties defined on **all** of the ``frames``.

    Properties only present on some of the frames are dropped, and a single
    warning listing them is emitted. If the values of a property mix strings
    with other types, all the values are converted to strings.

    :param frames: sequence of objects with a ``list_properties()`` method and
        supporting ``frame[name]``
    :return: dictionary mapping property names to a list with one value per
        frame; empty if ``frames`` is empty
    """
    if len(frames) == 0:
        # nothing to look at, and no first frame to start from
        return {}

    # start from the properties of the first frame, and remove whatever is
    # missing from one of the following frames
    first = frames[0]
    all_properties = {name: [first[name]] for name in first.list_properties()}
    extra = set()

    for frame in frames[1:]:
        current_properties = frame.list_properties()
        for name in current_properties:
            if name in all_properties:
                all_properties[name].append(frame[name])
            else:
                extra.add(name)

        # iterate over a copy of the keys, since entries are removed
        for name in list(all_properties):
            if name not in current_properties:
                del all_properties[name]
                extra.add(name)

    if len(extra) != 0:
        warnings.warn(
            "the following structure properties are only defined for a subset "
            f"of structures: {list(sorted(extra))}; they will be ignored",
            stacklevel=2,
        )

    # ensures that if a property is a mix of strings and numbers, everything
    # is converted to string (as these should be categorical properties)
    for name, values in all_properties.items():
        if any(isinstance(value, str) for value in values):
            all_properties[name] = [str(value) for value in values]

    return all_properties
| 192 | + |
| 193 | + |
def _chemfiles_get_atom_properties(frames, environments):
    """
    Collect the atom-level properties defined on **all** the atoms selected by
    ``environments``.

    Properties only present on some of these atoms are dropped, and a single
    warning listing them is emitted.

    :param frames: sequence of objects with an indexable ``atoms`` attribute;
        each atom has a ``list_properties()`` method and supports ``atom[name]``
    :param environments: sequence of ``(frame index, atom index, cutoff)``
        tuples; the cutoff is not used here. Must not be ``None``.
    :return: dictionary mapping property names to a list with one value per
        environment; empty if ``environments`` is empty
    """
    assert environments is not None
    if len(environments) == 0:
        # no atom selected, and no first atom to start from
        return {}

    # start from the properties of the first atom, and remove whatever is
    # missing from one of the following atoms
    first_frame_i, first_atom_i = environments[0][0], environments[0][1]
    first_atom = frames[first_frame_i].atoms[first_atom_i]
    all_properties = {
        name: [first_atom[name]] for name in first_atom.list_properties()
    }
    extra = set()

    for frame_i, atom_i, _ in environments[1:]:
        atom = frames[frame_i].atoms[atom_i]

        current_properties = atom.list_properties()
        for name in current_properties:
            if name in all_properties:
                all_properties[name].append(atom[name])
            else:
                extra.add(name)

        # iterate over a copy of the keys, since entries are removed
        for name in list(all_properties):
            if name not in current_properties:
                del all_properties[name]
                extra.add(name)

    if len(extra) != 0:
        warnings.warn(
            "the following atomic properties are only defined for a subset "
            f"of structures: {list(sorted(extra))}; they will be ignored",
            stacklevel=2,
        )

    return all_properties
| 229 | + |
| 230 | + |
def _chemfiles_extract_properties(structures, only=None, environments=None):
    """implementation of ``extract_properties`` for chemfiles

    Structure properties come from ``_chemfiles_get_structure_properties`` and
    atom properties from ``_chemfiles_get_atom_properties``. Both sets are
    passed through ``_remove_invalid_properties`` before being merged. When a
    name exists in both sets, the structure property is kept and a warning is
    emitted for the atom one.

    :param structures: sequence of chemfiles frames
    :param only: optional iterable of property names; when given, only these
        properties are considered
    :param environments: optional sequence of ``(structure, atom, cutoff)``
        tuples selecting the atoms to use for atom properties; defaults to all
        atoms of all structures
    :return: dictionary mapping property names to
        ``{"target": "structure" | "atom", "values": numpy array}``
    """
    all_properties = _chemfiles_get_structure_properties(structures)
    if only is None:
        selected = all_properties
    else:
        selected = {
            name: all_properties[name] for name in only if name in all_properties
        }

    # create property in the format expected by create_input
    properties = {
        name: {"target": "structure", "values": np.asarray(value)}
        for name, value in selected.items()
    }

    # called for its effect on the dictionary; the return value is not used
    _remove_invalid_properties(properties, "chemfiles")

    if environments is None:
        environments = _chemfiles_all_atomic_environments(structures, cutoff=0.0)

    atom_properties = _chemfiles_get_atom_properties(structures, environments)

    if only is None:
        selected = atom_properties
    else:
        selected = {
            name: atom_properties[name] for name in only if name in atom_properties
        }

    # create property in the format expected by create_input
    atom_properties = {
        name: {"target": "atom", "values": np.stack(value, axis=0)}
        for name, value in selected.items()
    }
    # validate the atom properties: this used to check `properties` a second
    # time, leaving the atom properties unchecked
    _remove_invalid_properties(atom_properties, "chemfiles")

    for name, values in atom_properties.items():
        if name in properties:
            warnings.warn(
                f"a property named '{name}' is defined for both atoms and structures, "
                "the atom one will be ignored",
                stacklevel=2,
            )
        else:
            properties[name] = values

    return properties
0 commit comments