Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions src/aiida_epw/workflows/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,22 +188,22 @@ def define(cls, spec):
spec.exit_code(
403,
"ERROR_SUB_PROCESS_FAILED_PHONON",
message="The electron-phonon `PhBaseWorkChain` sub process failed",
message="The `PhBaseWorkChain` subprocess failed",
)
spec.exit_code(
404,
"ERROR_SUB_PROCESS_FAILED_WANNIER90",
message="The `Wannier90BandsWorkChain` sub process failed",
message="The `Wannier90BandsWorkChain/Wannier90OptimizeWorkChain` subprocess failed",
)
spec.exit_code(
405,
"ERROR_SUB_PROCESS_FAILED_EPW",
message="The `EpwWorkChain` sub process failed",
message="The `EpwBaseWorkChain` subprocess failed",
)
spec.exit_code(
406,
"ERROR_SUB_PROCESS_FAILED_EPW_BANDS",
message="The `EpwBaseWorkChain` sub process failed",
message="The `EpwBaseWorkChain` for bands interpolation subprocess failed",
)
@classmethod
def get_protocol_filepath(cls):
Expand Down Expand Up @@ -242,34 +242,31 @@ def get_builder_from_protocol(
w90_bands_inputs = inputs.get("w90_bands", {})
pseudo_family = inputs.pop("pseudo_family", None)

if wannier_projection_type == WannierProjectionType.ATOMIC_PROJECTORS_QE:
if reference_bands is None:
raise ValueError(
f"reference_bands must be specified for {wannier_projection_type}"
)
if reference_bands:
w90_bands = Wannier90OptimizeWorkChain.get_builder_from_protocol(
structure=structure,
codes=codes,
pseudo_family=pseudo_family,
overrides=w90_bands_inputs,
projection_type=wannier_projection_type,
reference_bands=reference_bands,
bands_kpoints=bands_kpoints,
)
w90_bands.separate_plotting = False
# pop useless inputs, otherwise the builder validation will fail
# at validating empty inputs
w90_bands.pop("projwfc", None)
elif wannier_projection_type == WannierProjectionType.SCDM:
else:
w90_bands = Wannier90BandsWorkChain.get_builder_from_protocol(
structure=structure,
codes=codes,
pseudo_family=pseudo_family,
overrides=w90_bands_inputs,
projection_type=wannier_projection_type,
bands_kpoints=bands_kpoints,
)
else:
raise ValueError(
f"Unsupported wannier_projection_type: {wannier_projection_type}"
)
if wannier_projection_type == WannierProjectionType.ATOMIC_PROJECTORS_QE:
w90_bands.pop("projwfc", None)

w90_bands.pop("structure", None)
w90_bands.pop("open_grid", None)
Expand Down Expand Up @@ -358,17 +355,20 @@ def generate_reciprocal_points(self):

def run_wannier90(self):
"""Run the wannier90 workflow."""
if "projwfc" in self.inputs.w90_bands:
w90_class = Wannier90BandsWorkChain
else:
inputs = AttributeDict(
self.exposed_inputs(
Wannier90OptimizeWorkChain, namespace="w90_bands"
)
)
if "reference_bands" in self.inputs.w90_bands:
w90_class = Wannier90OptimizeWorkChain
# inputs.pop('projwfc')
else:
w90_class = Wannier90BandsWorkChain

self.ctx.w90_class_name = w90_class.get_name()
self.report(f"Running a {self.ctx.w90_class_name}.")

inputs = AttributeDict(
self.exposed_inputs(Wannier90OptimizeWorkChain, namespace="w90_bands")
)
inputs.metadata.call_link_label = "w90_bands"
inputs.structure = self.inputs.structure

Expand Down Expand Up @@ -426,9 +426,10 @@ def inspect_ph(self):
def run_epw(self):
"""Run the `EpwBaseWorkChain`."""
inputs = AttributeDict(
self.exposed_inputs(EpwBaseWorkChain), namespace="epw_base"
self.exposed_inputs(EpwBaseWorkChain, namespace="epw_base")
)

inputs.structure = self.inputs.structure
# The EpwBaseWorkChain will take the parent folder of the previous
# PhCalculation, PwCalculation, and Wannier90Calculation.
inputs.parent_folder_ph = self.ctx.workchain_ph.outputs.remote_folder
Expand Down Expand Up @@ -476,7 +477,7 @@ def inspect_epw(self):
return self.exit_codes.ERROR_SUB_PROCESS_FAILED_EPW

def should_run_epw_bands(self):
"""Check if the `EpwBaseWorkChain` should be run in bands mode."""
"""Check if the bands interpolation should be run."""
return "epw_bands" in self.inputs

def run_epw_bands(self):
Expand Down
2 changes: 2 additions & 0 deletions src/aiida_epw/workflows/protocols/prep.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ default_inputs:
- aiida.kgmap
- aiida.kmap
- aiida.ukk
- aiida.mmn
- aiida.bvec
- out/aiida.epmatwp
- save
parameters:
Expand Down