Source code for aiida_catmat.workchains.base

"""``BaseWorkChain`` to run a VASP calculation"""
from pymatgen.core.structure import Structure
from pymatgen.io.vasp import Oszicar

from aiida.common import AttributeDict
from aiida.engine import calcfunction
from aiida.engine import BaseRestartWorkChain, ProcessHandlerReport, process_handler, while_
from aiida.orm import Dict, FolderData, CalcJobNode, StructureData
from aiida.plugins import CalculationFactory

from aiida_catmat.calcfunctions import dict_merge
from aiida_catmat.parsers import STDERR_ERRS, STDOUT_ERRS

# StructureData = DataFactory('structure')  # pylint: disable=invalid-name
VaspCalculation = CalculationFactory('vasp.vasp')  # pylint: disable=invalid-name


[docs]@calcfunction def update_incar(incar: Dict, modifications: Dict) -> Dict: """Updates the current ``INCAR`` with proposed modifications. Args: incar (Dict): Current ``INCAR`` modifications (Dict): Proposed modifications to handle the error. Returns: Dict: The updated ``INCAR``. """ incar = incar.get_dict() modifications = modifications.get_dict() dict_merge(incar, modifications) return Dict(dict=incar)
[docs]@calcfunction def apply_strain_on_structure(retrived_folder: FolderData) -> StructureData: """Applies 0.2 strain on structure Args: retrived_folder (FolderData): The retrieved folder in the repository that contains ``CONTCAR`` Returns: StructureData: The structure after applying 0.2 strain. """ with retrived_folder.open('CONTCAR') as handler: structure = Structure.from_file(handler.name) structure.apply_strain(0.2) return StructureData(pymatgen_structure=structure)
[docs]def get_stdout_errs(calculation: CalcJobNode) -> set: """Parses the ``_scheduler-stdout.txt`` and searches for pre-defined error messages. Args: calculation (CalcJobNode): The calculation `Node`. Returns: set: A set of found error messages in ``_scheduler-stdout.txt`` """ errors = set() errors_subset_to_catch = list(STDOUT_ERRS.keys()) with calculation.outputs.retrieved.open('_scheduler-stdout.txt') as handler: for line in handler: l = line.strip() #pylint: disable=invalid-name for err, msgs in STDOUT_ERRS.items(): if err in errors_subset_to_catch: for msg in msgs: if l.find(msg) != -1: errors.add(err) return errors
[docs]def get_stderr_errs(calculation: CalcJobNode) -> set: """Parses the ``_scheduler-stderr.txt`` and searches for pre-defined error messages. Args: calculation (CalcJobNode): The calculation ``Node``. Returns: set: A set of found error messages in ``_scheduler-stderr.txt`` """ errors = set() errors_subset_to_catch = list(STDERR_ERRS.keys()) with calculation.outputs.retrieved.open('_scheduler-stderr.txt') as handler: for line in handler: l = line.strip() #pylint: disable=invalid-name for err, msgs in STDOUT_ERRS.items(): if err in errors_subset_to_catch: for msg in msgs: if l.find(msg) != -1: errors.add(err) return errors
#pylint: disable=inconsistent-return-statements #pylint: disable=too-many-public-methods
[docs]class VaspBaseWorkChain(BaseRestartWorkChain): """Workchain to run a ``VASP`` calculation with automated error handling and restarts.""" _process_class = VaspCalculation
[docs] @classmethod def define(cls, spec): super().define(spec) spec.expose_inputs(VaspCalculation, namespace='vasp') spec.outline( cls.setup, while_(cls.should_run_process)( cls.run_process, cls.inspect_process, ), cls.results, ) spec.expose_outputs(VaspCalculation)
[docs] def setup(self): """Call the ``setup`` of the ``BaseRestartWorkChain`` and then create the inputs dictionary in ``self.ctx.inputs``. This ``self.ctx.inputs`` dictionary will be used by the ``BaseRestartWorkChain`` to submit the calculations in the internal loop.""" super().setup() self.ctx.inputs = AttributeDict(self.exposed_inputs(VaspCalculation, 'vasp')) self.ctx.parameters = self.ctx.inputs.parameters self.ctx.modifications = {} self.ctx.err_count = {}
[docs] def report_error_handled(self, calculation: CalcJobNode, action: str) -> None: """Report an action taken for a calculation that has failed. Args: calculation (CalcJobNode): the failed calculation node action (str): a string message with the action taken """ arguments = [calculation.process_label, calculation.pk, calculation.exit_status, calculation.exit_message] self.report('{}<{}> failed with exit status {}: {}'.format(*arguments)) self.report('Action taken: {}'.format(action))
[docs] @process_handler(priority=570, enabled=True) def handle_timeout(self, calculation): """Error handler that restarts calculation finished with ``TIMEOUT`` ExitCode.""" self.ctx.stdout_errors = get_stdout_errs(calculation) self.ctx.stderr_errors = get_stderr_errs(calculation) if 'walltime' in self.ctx.stderr_errors: self.report_error_handled( calculation, 'Timeout handler. Adding remote folder as input to use binary restart.' ) self.ctx.inputs.restart_folder = calculation.outputs.remote_folder return ProcessHandlerReport(False)
[docs] @process_handler(priority=1, enabled=True) def apply_modifications(self, calculation): """Apply all requested modifications""" if self.ctx.modifications: self.ctx.inputs.parameters = update_incar(self.ctx.parameters, Dict(dict=self.ctx.modifications)) self.ctx.modifications = {} self.report('Applied all modifications for {}<{}>'.format(calculation.process_label, calculation.pk)) return ProcessHandlerReport(False)
[docs] @process_handler(priority=300, enabled=True) def handle_lreal(self, calculation): """Handle ``ERROR_LREAL_SMALL_SUPERCELL`` exit code""" if 'lreal' in self.ctx.stdout_errors: self.ctx.modifications.update({'LREAL': False}) action = 'ERROR_LREAL_SMALL_SUPERCELL: LREAL is set to False' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=310, enabled=True) def handle_rsphere(self, calculation): """Handle ``ERROR_RSPHERE`` exit code""" if 'rsphere' in self.ctx.stdout_errors: self.ctx.modifications.update({'LREAL': False}) action = 'ERROR_RSPHERE: LREAL is set to False' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=320, enabled=True) def handle_zbrent(self, calculation): """Handle ``ERROR_ZBRENT`` exit code""" if ('zbrent' in self.ctx.stdout_errors) and (self.ctx.err_count.get('zbrent', 0)): ediff = self.ctx.parameters.get_dict().get('EDIFF', 1e-6) * 0.01 self.ctx.modifications.update({'EDIFF': ediff}) self.ctx.err_count.update({'zbrent': 1}) action = f'ERROR_ZBRENT: EDIFF is decreased to {ediff}' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=100, enabled=False) def handle_tetrahedron(self, calculation): """Handle ``ERROR_TETRAHEDRON`` exit code""" if 'tet' in self.ctx.stdout_errors: if 'kspacing' in self.ctx.inputs: old_kspacing = self.ctx.inputs.kspacing.value new_kspacing = old_kspacing * 0.8 self.ctx.inputs.kspacing = new_kspacing action = f'ERROR_TETRAHEDRON: KSPACING is decreased by 80%: <{old_kspacing}> to <{new_kspacing}>' self.report_error_handled(calculation, action) else: self.ctx.modifications.update({'ISMEAR': 0, 'SIGMA': 0.05}) action = f'ERROR_TETRAHEDRON: ISMEAR is set to 0 and SIGMA 0.05' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=110, enabled=True) def handle_inverse_rotation_matrix(self, calculation): """Handle ``ERROR_INVERSE_ROTATION_MATRIX`` exit code""" if 'inv_rot_mat' in self.ctx.stdout_errors: self.ctx.modifications.update({'SYMPREC': 1e-8}) action = 'ERROR_INVERSE_ROTATION_MATRIX: SYMPREC is decreased to 1E-08' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=120, enabled=True) def handle_subspace_matrix(self, calculation): """Handle ``ERROR_SUBSPACEMATRIX`` exit code""" if 'subspacematrix' in self.ctx.stdout_errors: self.ctx.modifications.update({'LREAL': False}) action = 'ERROR_SUBSPACEMATRIX: LREAL is set to False' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=130, enabled=True) def handle_amin(self, calculation): """Handle ``ERROR_AMIN`` exit code""" if 'amin' in self.ctx.stdout_errors: self.ctx.modifications.update({'AMIN': 0.01}) action = 'ERROR_AMIN: AMIN is set to 0.01' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=140, enabled=True) def handle_pricel(self, calculation): """Handle ``ERROR_PRICEL`` exit code""" if 'pricel' in self.ctx.stdout_errors: self.ctx.modifications.update({'ISYM': 0, 'SYMPREC': 1e-8}) action = 'ERROR_PRICEL: ISYM is set to zero and SYMPREC to 1E-08' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=150, enabled=True) def handle_brions(self, calculation): """Handle ``ERROR_BRIONS`` exit code""" if 'brions' in self.ctx.stdout_errors: potim = self.ctx.parameters.get_dict().get('POTIM', 0.5) + 0.1 self.ctx.modifications.update({'POTIM': potim}) action = f'ERROR_BRIONS: POTIM is set to <{potim}>' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=160, enabled=True) def handle_pssyevx(self, calculation): """Handle ``ERROR_PSSYEVX`` exit code""" if 'pssyevx' in self.ctx.stdout_errors: self.ctx.modifications.update({'ALGO': 'Normal'}) action = 'ERROR_PSSYEVX: ALGO is set to Normal' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=170, enabled=True) def handle_eddrmm(self, calculation): """Handle ``ERROR_EDDRMM`` exit code""" if 'eddrmm' in self.ctx.stdout_errors: self.ctx.modifications.update({'ISTART': 0, 'ICHARG': 2}) if self.ctx.parameters['ALGO'] in ['Fast', 'VeryFast']: self.ctx.modifications.update({'ALGO': 'Normal'}) action = 'ERROR_EDDRMM: ALGO is set to Normal, ISTART to 0 and ICHARG to 2' else: potim = self.ctx.parameters.get_dict().get('POTIM', 0.5) / 2.0 self.ctx.modifications.update({'POTIM': potim}) action = f'ERROR_EDDRMM: POTIM is set to <{potim}>, ISTART to 0 and ICHARG to 2' self.ctx.inputs.restart_folder = calculation.outputs.remote_folder self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=180, enabled=True) def handle_edddav(self, calculation): """Handle ``ERROR_EDDDAV`` exit code""" if 'edddav' in self.ctx.stdout_errors: self.ctx.modifications.update({'ALGO': 'All', 'ISTART': 0, 'ICHARG': 2}) action = 'ERROR_EDDDAV: ALGO is set to Normal, ISTART to 0 and ICHARG to 2' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=190, enabled=True) def handle_grad_not_orth(self, calculation): """Handle ``ERROR_GRAD_NOT_ORTH`` exit code""" if 'grad_not_orth' in self.ctx.stdout_errors: self.ctx.modifications.update({'ISMEAR': 0, 'SIGMA': 0.05}) action = f'ERROR_GRAD_NOT_ORTH: ISMEAR is set to zero and SIGMA 0.05' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
# TODO: double check the solution #pylint: disable=fixme
[docs] @process_handler(priority=200, enabled=False) def handle_zheev(self, calculation): """Handle ``ERROR_ZHEEV`` exit code""" if 'zheev' in self.ctx.stdout_errors: self.ctx.modifications.update({'ALGO': 'Exact'}) action = 'ERROR_ZHEEV: ALGO is set to Exact' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=210, enabled=True) def handle_elf_kpar(self, calculation): """Handle ``ERROR_ELF_KPAR`` exit code""" if 'elf_kpar' in self.ctx.stdout_errors: self.ctx.modifications.update({'KPAR': 1}) action = 'ERROR_ELF_KPAR: KPAR is set to 1' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
# TODO: double check the solution. #pylint: disable=fixme
[docs] @process_handler(priority=220, enabled=False) def handle_rhosyg(self, calculation): """Handle ``ERROR_RHOSYG`` exit code""" if 'rhosyg' in self.ctx.stdout_errors: self.ctx.modifications.update({ 'ISYM': 0, 'SYMPREC': 1e-4 # ?? }) action = 'ERROR_RHOSYG: ISYM is set to 0 and SYMPREC to 1E-04' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
# TODO: double check the solution #pylint: disable=fixme
[docs] @process_handler(priority=230, enabled=False) def handle_posmap(self, calculation): """Handle ``ERROR_POSMAP`` exit code""" if 'posmap' in self.ctx.stdout_errors: self.ctx.modifications.update({'SYMPREC': 1e-6}) action = 'ERROR_POSMAP: SYMPREC is set to 1E-06' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=240, enabled=True) def handle_point_group(self, calculation): """Handle ``ERROR_POINT_GROUP`` exit code""" if 'point_group' in self.ctx.stdout_errors: self.ctx.modifications.update({'ISYM': 0}) action = 'ERROR_POINT_GROUP: ISYM is set to 0!' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
[docs] @process_handler(priority=250, enabled=True) def handle_zpotrf(self, calculation): """Handle ``ERROR_ZPOTRF`` exit code""" if 'zportf' in self.ctx.stdout_errors: try: with calculation.outputs.retrieved.open('OSZICAR') as handler: oszcar = Oszicar(handler.name) nsteps = len(oszcar.ionic_steps) except Exception: #pylint: disable=broad-except nsteps = 0 if nsteps >= 0: potim = self.ctx.parameters.get_dict().get('POTIM', 0.5) / 2.0 self.ctx.modifications.update({'ISYM': 0, 'POTIM': potim}) action = f'ERROR_ZPOTRF: ISYM is set to 0 and POTIM to {potim}!' elif self.ctx.parameters.get_dict().get('NSW', 0) == 0 or self.ctx.parameters.get_dict().get('ISIF', 0) in range(3): self.ctx.modifications.update({'ISYM': 0}) action = 'ERROR_ZPOTRF: ISYM is set to 0!' else: self.ctx.inputs.structure = apply_strain_on_structure(calculation.outputs.retrieved) action = 'ERROR_ZPOTRF: Applied 0.2 strain on the strcuture from CONTCAR' self.report_error_handled(calculation, action) return ProcessHandlerReport(False)
# EOF