From 1059962c93c0d398f74cfe8ca741d5c9f2b919ae Mon Sep 17 00:00:00 2001 From: ymzhang0 Date: Tue, 20 Jan 2026 16:54:30 +0100 Subject: [PATCH] Fix bugs in EpwPrepWorkChain and update protocol - Fix logic for selecting Wannier90 workchain class (Optimize vs Bands) based on projection type and reference bands. - Ensure correct inputs are prepared for the selected Wannier90 workchain in `run_wannier90`. - Explicitly pass `structure` to `EpwBaseWorkChain` inputs. - Update exit code messages for better clarity. - Update protocol to include `aiida.mmn` and `aiida.bvec` in `remote_symlink_list`. --- src/aiida_epw/workflows/prep.py | 45 +++++++++++---------- src/aiida_epw/workflows/protocols/prep.yaml | 2 + 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/src/aiida_epw/workflows/prep.py b/src/aiida_epw/workflows/prep.py index 6b12474..987839a 100644 --- a/src/aiida_epw/workflows/prep.py +++ b/src/aiida_epw/workflows/prep.py @@ -188,22 +188,22 @@ def define(cls, spec): spec.exit_code( 403, "ERROR_SUB_PROCESS_FAILED_PHONON", - message="The electron-phonon `PhBaseWorkChain` sub process failed", + message="The `PhBaseWorkChain` subprocess failed", ) spec.exit_code( 404, "ERROR_SUB_PROCESS_FAILED_WANNIER90", - message="The `Wannier90BandsWorkChain` sub process failed", + message="The `Wannier90BandsWorkChain/Wannier90OptimizeWorkChain` subprocess failed", ) spec.exit_code( 405, "ERROR_SUB_PROCESS_FAILED_EPW", - message="The `EpwWorkChain` sub process failed", + message="The `EpwBaseWorkChain` subprocess failed", ) spec.exit_code( 406, "ERROR_SUB_PROCESS_FAILED_EPW_BANDS", - message="The `EpwBaseWorkChain` sub process failed", + message="The `EpwBaseWorkChain` for bands interpolation subprocess failed", ) @classmethod def get_protocol_filepath(cls): @@ -242,16 +242,13 @@ def get_builder_from_protocol( w90_bands_inputs = inputs.get("w90_bands", {}) pseudo_family = inputs.pop("pseudo_family", None) - if wannier_projection_type == WannierProjectionType.ATOMIC_PROJECTORS_QE: - if reference_bands is None: - raise ValueError( - f"reference_bands must be specified for {wannier_projection_type}" - ) + if reference_bands: w90_bands = Wannier90OptimizeWorkChain.get_builder_from_protocol( structure=structure, codes=codes, pseudo_family=pseudo_family, overrides=w90_bands_inputs, + projection_type=wannier_projection_type, reference_bands=reference_bands, bands_kpoints=bands_kpoints, ) @@ -259,17 +256,17 @@ def get_builder_from_protocol( # pop useless inputs, otherwise the builder validation will fail # at validating empty inputs w90_bands.pop("projwfc", None) - elif wannier_projection_type == WannierProjectionType.SCDM: + else: w90_bands = Wannier90BandsWorkChain.get_builder_from_protocol( structure=structure, codes=codes, pseudo_family=pseudo_family, overrides=w90_bands_inputs, + projection_type=wannier_projection_type, + bands_kpoints=bands_kpoints, ) - else: - raise ValueError( - f"Unsupported wannier_projection_type: {wannier_projection_type}" - ) + if wannier_projection_type == WannierProjectionType.ATOMIC_PROJECTORS_QE: + w90_bands.pop("projwfc", None) w90_bands.pop("structure", None) w90_bands.pop("open_grid", None) @@ -358,17 +355,20 @@ def generate_reciprocal_points(self): def run_wannier90(self): """Run the wannier90 workflow.""" - if "projwfc" in self.inputs.w90_bands: - w90_class = Wannier90BandsWorkChain - else: + inputs = AttributeDict( + self.exposed_inputs( + Wannier90OptimizeWorkChain, namespace="w90_bands" + ) + ) + if "reference_bands" in self.inputs.w90_bands: w90_class = Wannier90OptimizeWorkChain + # inputs.pop('projwfc') + else: + w90_class = Wannier90BandsWorkChain self.ctx.w90_class_name = w90_class.get_name() self.report(f"Running a {self.ctx.w90_class_name}.") - inputs = AttributeDict( - self.exposed_inputs(Wannier90OptimizeWorkChain, namespace="w90_bands") - ) inputs.metadata.call_link_label = "w90_bands" inputs.structure = self.inputs.structure @@ -426,9 +426,10 @@ def inspect_ph(self): def run_epw(self): """Run the `EpwBaseWorkChain`.""" inputs = AttributeDict( - self.exposed_inputs(EpwBaseWorkChain), namespace="epw_base" + self.exposed_inputs(EpwBaseWorkChain, namespace="epw_base") ) + inputs.structure = self.inputs.structure # The EpwBaseWorkChain will take the parent folder of the previous # PhCalculation, PwCalculation, and Wannier90Calculation. inputs.parent_folder_ph = self.ctx.workchain_ph.outputs.remote_folder @@ -476,7 +477,7 @@ def inspect_epw(self): return self.exit_codes.ERROR_SUB_PROCESS_FAILED_EPW def should_run_epw_bands(self): - """Check if the `EpwBaseWorkChain` should be run in bands mode.""" + """Check if the bands interpolation should be run.""" return "epw_bands" in self.inputs def run_epw_bands(self): diff --git a/src/aiida_epw/workflows/protocols/prep.yaml b/src/aiida_epw/workflows/protocols/prep.yaml index 4143b26..26e9eb4 100644 --- a/src/aiida_epw/workflows/protocols/prep.yaml +++ b/src/aiida_epw/workflows/protocols/prep.yaml @@ -41,6 +41,8 @@ default_inputs: - aiida.kgmap - aiida.kmap - aiida.ukk + - aiida.mmn + - aiida.bvec - out/aiida.epmatwp - save parameters: