Source code for janus_core.helpers.mlip_calculators

"""
Configure MLIP calculators.

Similar in spirit to matcalc and quacc approaches
- https://github.com/materialsvirtuallab/matcalc
- https://github.com/Quantum-Accelerators/quacc.git
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, get_args

from ase.calculators.mixing import SumCalculator

from janus_core.helpers.janus_types import Architectures, Devices, PathLike
from janus_core.helpers.utils import none_to_dict

if TYPE_CHECKING:
    from ase.calculators.calculator import Calculator
    import torch


def _set_model_path(
    model_path: PathLike | None = None,
    kwargs: dict[str, Any] | None = None,
) -> PathLike | torch.nn.Module | None:
    """
    Resolve the model to use from `model_path` or calculator keyword arguments.

    Parameters
    ----------
    model_path : PathLike | None
        Path to MLIP file.
    kwargs : dict[str, Any] | None
        Dictionary of additional keyword arguments passed to the selected
        calculator. If it contains one of "model", "model_paths", "potential" or
        "path", that entry is popped and used as the model.

    Returns
    -------
    PathLike | torch.nn.Module | None
        Path to MLIP model file, loaded model, or None.

    Raises
    ------
    ValueError
        If `model_path` is combined with a model keyword argument, or if more
        than one model keyword argument is given.
    """
    (kwargs,) = none_to_dict(kwargs)

    # Keyword arguments that different MLIPs use in place of `model_path`.
    # Note: "model" for chgnet (but not mace_mp or mace_off) and "potential" may
    # refer to loaded PyTorch models rather than paths.
    model_keys = {"model", "model_paths", "potential", "path"}
    supplied = model_keys.intersection(kwargs)

    # An explicit `model_path` must not compete with a keyword argument
    if model_path and supplied:
        raise ValueError(
            "`model_path` cannot be used in combination with 'model', "
            "'model_paths', 'potential', or 'path'"
        )

    # At most one of the alternative keyword arguments may be given
    if len(supplied) > 1:
        raise ValueError(
            "Only one of 'model', 'model_paths', 'potential', and 'path' can be "
            "specified"
        )

    # Exactly one alternative given: take it out of kwargs and use it as the model
    if supplied:
        (key,) = supplied
        model_path = kwargs.pop(key)

    # Only convert to a Path when it points at something that exists; other
    # values (model names, loaded models, None) are passed through untouched
    if isinstance(model_path, (Path, str)):
        expanded = Path(model_path).expanduser()
        if expanded.exists():
            return expanded

    return model_path
def choose_calculator(
    arch: Architectures = "mace",
    device: Devices = "cpu",
    model_path: PathLike | None = None,
    **kwargs,
) -> Calculator:
    """
    Choose MLIP calculator to configure.

    Parameters
    ----------
    arch : Architectures
        MLIP architecture. Default is "mace".
    device : Devices
        Device to run calculator on. Default is "cpu".
    model_path : PathLike | None
        Path to MLIP file.
    **kwargs
        Additional keyword arguments passed to the selected calculator.

    Returns
    -------
    Calculator
        Configured MLIP calculator, with "version", "arch" and "model_path"
        recorded in its `parameters`.

    Raises
    ------
    ModuleNotFoundError
        MLIP module has not been correctly installed.
    ValueError
        Invalid device or architecture specified, or no model specified for an
        architecture without a default model.
    """
    # Resolve the model from `model_path` or an equivalent entry in kwargs
    model_path = _set_model_path(model_path, kwargs)

    if device not in get_args(Devices):
        raise ValueError(f"`device` must be one of: {get_args(Devices)}")

    # MLIP packages are imported lazily inside each branch, so only the selected
    # architecture needs to be installed
    if arch == "mace":
        from mace import __version__
        from mace.calculators import MACECalculator

        # No default `model_path`
        if model_path is None:
            raise ValueError(
                "Please specify `model_path`, as there is no "
                f"default model for {arch}"
            )
        # Default to float64 precision
        kwargs.setdefault("default_dtype", "float64")

        calculator = MACECalculator(model_paths=model_path, device=device, **kwargs)

    elif arch == "mace_mp":
        from mace import __version__
        from mace.calculators import mace_mp

        # Default to "small" model and float64 precision
        model_path = model_path if model_path else "small"
        kwargs.setdefault("default_dtype", "float64")

        calculator = mace_mp(model=model_path, device=device, **kwargs)

    elif arch == "mace_off":
        from mace import __version__
        from mace.calculators import mace_off

        # Default to "small" model and float64 precision
        model_path = model_path if model_path else "small"
        kwargs.setdefault("default_dtype", "float64")

        calculator = mace_off(model=model_path, device=device, **kwargs)

    elif arch == "m3gnet":
        from matgl import __version__, load_model
        from matgl.apps.pes import Potential
        from matgl.ext.ase import M3GNetCalculator
        import torch

        # Set before loading model to avoid type mismatches
        torch.set_default_dtype(torch.float32)
        kwargs.setdefault("stress_weight", 1.0 / 160.21766208)

        # Use potential (from kwargs) if specified
        # Otherwise, load the model if given a path, else use a default model
        if isinstance(model_path, Potential):
            potential = model_path
            model_path = "loaded_Potential"
        elif isinstance(model_path, Path):
            # A file path is replaced by its parent directory before loading
            if model_path.is_file():
                model_path = model_path.parent
            potential = load_model(model_path)
        elif isinstance(model_path, str):
            potential = load_model(model_path)
        else:
            model_path = "M3GNet-MP-2021.2.8-DIRECT-PES"
            potential = load_model(model_path)

        calculator = M3GNetCalculator(potential=potential, **kwargs)

    elif arch == "chgnet":
        from chgnet import __version__
        from chgnet.model.dynamics import CHGNetCalculator
        from chgnet.model.model import CHGNet
        import torch

        # Set before loading to avoid type mismatches
        torch.set_default_dtype(torch.float32)

        # Use loaded model (from kwargs) if specified
        # Otherwise, load the model if given a path, else use a default model
        if isinstance(model_path, CHGNet):
            model = model_path
            model_path = "loaded_CHGNet"
        elif isinstance(model_path, Path):
            model = CHGNet.from_file(model_path)
        elif isinstance(model_path, str):
            model = CHGNet.load(model_name=model_path, use_device=device)
        else:
            # No model is loaded here; "0.3.0" is only the label recorded in
            # the calculator parameters
            model_path = "0.3.0"
            model = None

        calculator = CHGNetCalculator(model=model, use_device=device, **kwargs)

    elif arch == "alignn":
        from alignn import __version__
        from alignn.ff.ff import (
            AlignnAtomwiseCalculator,
            default_path,
            get_figshare_model_ff,
        )

        # Set default path to directory containing config and model location
        if isinstance(model_path, Path):
            if model_path.is_file():
                model_path = model_path.parent
        # If a string, assume referring to model_name e.g. "v5.27.2024"
        elif isinstance(model_path, str):
            model_path = get_figshare_model_ff(model_name=model_path)
        else:
            model_path = default_path()

        calculator = AlignnAtomwiseCalculator(path=model_path, device=device, **kwargs)

    elif arch == "sevennet":
        from sevenn import __version__
        from sevenn.sevennet_calculator import SevenNetCalculator

        # The calculator is given a string: stringify paths, and fall back to a
        # default model name for anything that is neither a path nor a string
        if isinstance(model_path, Path):
            model_path = str(model_path)
        elif not isinstance(model_path, str):
            model_path = "SevenNet-0_11July2024"

        kwargs.setdefault("file_type", "checkpoint")
        kwargs.setdefault("sevennet_config", None)

        calculator = SevenNetCalculator(model=model_path, device=device, **kwargs)

    else:
        raise ValueError(
            f"Unrecognized {arch=}. Supported architectures "
            f"are {', '.join(get_args(Architectures))}"
        )

    # Record provenance of the configured calculator
    calculator.parameters["version"] = __version__
    calculator.parameters["arch"] = arch
    calculator.parameters["model_path"] = str(model_path)

    return calculator
def check_calculator(calc: Calculator, attribute: str) -> None:
    """
    Ensure calculator has ability to calculate properties.

    If the calculator is a SumCalculator of exactly two calculators whose second
    component is named "TorchDFTD3Calculator", the attribute is first copied from
    the first (MLIP) component onto the SumCalculator, so that the MLIP component
    is used for properties unrelated to dispersion.

    Parameters
    ----------
    calc : Calculator
        ASE Calculator to check.
    attribute : str
        Attribute to check calculator for.

    Raises
    ------
    NotImplementedError
        If `calc` does not have `attribute`, or it is not callable.
    """
    # If dispersion added to MLIP calculator, use only MLIP calculator for calculation
    if isinstance(calc, SumCalculator):
        components = calc.mixer.calcs
        if (
            len(components) == 2
            and components[1].name == "TorchDFTD3Calculator"
            and hasattr(components[0], attribute)
        ):
            mlip_attribute = getattr(components[0], attribute)
            setattr(calc, attribute, mlip_attribute)

    supported = hasattr(calc, attribute) and callable(getattr(calc, attribute))
    if not supported:
        raise NotImplementedError(
            f"The attached calculator does not currently support {attribute}"
        )