Source code for janus_core.helpers.mlip_calculators

"""
Configure MLIP calculators.

Similar in spirit to matcalc and quacc approaches
- https://github.com/materialsvirtuallab/matcalc
- https://github.com/Quantum-Accelerators/quacc.git
"""

from __future__ import annotations

from os import environ
from pathlib import Path
from typing import TYPE_CHECKING, Any, get_args

from ase.calculators.mixing import SumCalculator

from janus_core.helpers.janus_types import Architectures, Devices, PathLike
from janus_core.helpers.utils import none_to_dict

if TYPE_CHECKING:
    from ase.calculators.calculator import Calculator
    import torch


[docs] def _set_model( model: PathLike | None = None, kwargs: dict[str, Any] | None = None, ) -> PathLike | torch.nn.Module | None: """ Set `model`. Parameters ---------- model Path to MLIP file. kwargs Dictionary of additional keyword arguments passed to the selected calculator. Returns ------- PathLike | torch.nn.Module | None Name of MLIP model, or path to MLIP model file or loaded model. """ (kwargs,) = none_to_dict(kwargs) # kwargs that may be used for `model` for different MLIPs # Note: "model" for chgnet (but not mace_mp or mace_off) and "potential" may refer # to loaded PyTorch models model_kwargs = { "model_path", "model_paths", "potential", "path", "model_name", "checkpoint_path", } present = kwargs.keys() & model_kwargs # Use model if specified, but check not also specified via kwargs if model and present: raise ValueError( "`model` cannot be used in combination with 'model_path', " "'model_paths', 'potential', or 'path'" ) if len(present) > 1: # Check at most one suitable kwarg is specified raise ValueError( "Only one of 'model_path', 'model_paths', 'potential', and 'path' can be " "specified" ) if present: # Set model from kwargs if any are specified model = kwargs.pop(present.pop()) # Convert to path if file/directory exists if isinstance(model, Path | str) and Path(model).expanduser().exists(): return Path(model).expanduser() return model
[docs] def _set_no_weights_only_load(): """Set environment variable to fix models for torch 2.6.""" environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
[docs] def choose_calculator( arch: Architectures, device: Devices = "cpu", model: PathLike | None = None, **kwargs, ) -> Calculator: """ Choose MLIP calculator to configure. Parameters ---------- arch MLIP architecture. device Device to run calculator on. Default is "cpu". model MLIP model label, path to model, or loaded model. Default is `None`. **kwargs Additional keyword arguments passed to the selected calculator. Returns ------- Calculator Configured MLIP calculator. Raises ------ ModuleNotFoundError MLIP module not correctly been installed. ValueError Invalid architecture specified. """ model = _set_model(model, kwargs) if device not in get_args(Devices): raise ValueError(f"`device` must be one of: {get_args(Devices)}") # Fix torch 2.6 (must be before MLIP modules are loaded) _set_no_weights_only_load() match arch: case "mace": from mace import __version__ from mace.calculators import MACECalculator # No default `model` if model is None: raise ValueError( f"Please specify `model`, as there is no default model for {arch}" ) # Default to float64 precision kwargs.setdefault("default_dtype", "float64") calculator = MACECalculator(model_paths=model, device=device, **kwargs) case "mace_mp": from mace import __version__ from mace.calculators import mace_mp # Default to "small" model and float64 precision model = model if model else "small" kwargs.setdefault("default_dtype", "float64") calculator = mace_mp(model=model, device=device, **kwargs) case "mace_off": from mace import __version__ from mace.calculators import mace_off # Default to "small" model and float64 precision model = model if model else "small" kwargs.setdefault("default_dtype", "float64") calculator = mace_off(model=model, device=device, **kwargs) case "m3gnet": from matgl import __version__, load_model from matgl.apps.pes import Potential from matgl.ext.ase import M3GNetCalculator import torch # Set before loading model to avoid type mismatches torch.set_default_dtype(torch.float32) 
kwargs.setdefault("stress_weight", 1.0 / 160.21766208) # Use potential (from kwargs) if specified # Otherwise, load the model if given a path, else use a default model match model: case Potential(): potential = model model = "loaded_Potential" case Path(): if model.is_file(): model = model.parent potential = load_model(model) case str(): potential = load_model(model) case _: model = "M3GNet-MP-2021.2.8-DIRECT-PES" potential = load_model(model) calculator = M3GNetCalculator(potential=potential, **kwargs) case "chgnet": from chgnet import __version__ from chgnet.model.dynamics import CHGNetCalculator from chgnet.model.model import CHGNet import torch # Set before loading to avoid type mismatches torch.set_default_dtype(torch.float32) # Use loaded model (from kwargs) if specified # Otherwise, load the model if given a path, else use a default model match model: case CHGNet(): loaded_model = model model = "loaded_CHGNet" case Path(): loaded_model = CHGNet.from_file(model) case str(): loaded_model = CHGNet.load(model_name=model, use_device=device) case _: model = __version__ loaded_model = None calculator = CHGNetCalculator( model=loaded_model, use_device=device, **kwargs ) case "alignn": from alignn import __version__ from alignn.ff.ff import ( AlignnAtomwiseCalculator, default_path, get_figshare_model_ff, ) # Set default path to directory containing config and model location match model: case Path(): if model.is_file(): model = model.parent # If a string, assume referring to model_name e.g. 
"v5.27.2024" case str(): model = get_figshare_model_ff(model_name=model) case _: model = default_path() calculator = AlignnAtomwiseCalculator(path=model, device=device, **kwargs) case "sevennet": from sevenn import __version__ from sevenn.sevennet_calculator import SevenNetCalculator import torch # Set before loading model to avoid type mismatches torch.set_default_dtype(torch.float32) match model: case Path() | str(): model = str(model) case _: model = "SevenNet-0_11July2024" kwargs.setdefault("file_type", "checkpoint") kwargs.setdefault("sevennet_config", None) calculator = SevenNetCalculator(model=model, device=device, **kwargs) case "nequip": from nequip import __version__ from nequip.ase import NequIPCalculator # No default `model` if model is None: raise ValueError( f"Please specify `model`, as there is no default model for {arch}" ) model = str(model) calculator = NequIPCalculator.from_deployed_model( model_path=model, device=device, **kwargs ) case "dpa3": from deepmd import __version__ from deepmd.calculator import DP # No default `model` if model is None: # From https://matbench-discovery.materialsproject.org/models/dpa3-v1-mptrj raise ValueError( "Please specify `model`, as there is no " f"default model for {arch} " "e.g. 
https://bohrium-api.dp.tech/ds-dl/dpa3openlam-74ng-v3.zip" ) model = str(model) calculator = DP(model=model, **kwargs) case "orb": from orb_models import __version__ from orb_models.forcefield.calculator import ORBCalculator from orb_models.forcefield.direct_regressor import DirectForcefieldRegressor import orb_models.forcefield.pretrained as orb_ff match model: case DirectForcefieldRegressor(): loaded_model = model model = "loaded_DirectForcefieldRegressor" case str() if hasattr(orb_ff, model.replace("-", "_")): loaded_model = getattr(orb_ff, model.replace("-", "_"))() case None: # Default model model = "orb_v3_conservative_20_omat" loaded_model = getattr(orb_ff, model)() case _: raise ValueError( "`model` must be a `DirectForcefieldRegressor`, pre-trained " "model label (e.g. 'orb-v2'), or `None` (uses default, orb-v2)" ) calculator = ORBCalculator(model=loaded_model, device=device, **kwargs) case "mattersim": from mattersim import __version__ from mattersim.forcefield import MatterSimCalculator from torch.nn import Module potential = None match model: case Module(): potential = model model = "loaded_Module" case Path() | str(): model = str(model) case None: model = "mattersim-v1.0.0-5M" calculator = MatterSimCalculator( potential=potential, load_path=model, device=device, **kwargs ) case "grace": from tensorpotential.calculator import grace_fm __version__ = "0.5.1" # Default model model = model if model else "GRACE-2L-OMAT" if isinstance(model, Path): model = str(model) calculator = grace_fm(model, **kwargs) case "equiformer" | "esen": from fairchem.core import OCPCalculator, __version__ match arch, model: case ("equiformer", None): model = "EquiformerV2-31M-S2EF-OC20-All+MD" case ("esen", None): model = "eSEN-30M-OMAT24" case _: pass model_name = None checkpoint_path = None if isinstance(model, Path) and model.exists(): checkpoint_path = str(model) else: model_name = str(model) kwargs.setdefault("local_cache", Path("~/.cache/fairchem").expanduser()) cpu = True 
if device == "cpu" else False calculator = OCPCalculator( model_name=model_name, checkpoint_path=checkpoint_path, cpu=cpu, **kwargs, ) case _: raise ValueError( f"Unrecognized {arch=}. Suported architectures " f"are {', '.join(Architectures.__args__)}" ) if isinstance(model, Path): model = model.as_posix() calculator.parameters["version"] = __version__ calculator.parameters["arch"] = arch calculator.parameters["model"] = str(model) return calculator
def check_calculator(calc: Calculator, attribute: str) -> None:
    """
    Ensure calculator has ability to calculate properties.

    If the calculator is a SumCalculator that includes the TorchDFTD3Calculator,
    this also sets the relevant function so that the MLIP component of the
    calculator is used for properties unrelated to dispersion.

    Parameters
    ----------
    calc
        ASE Calculator to check.
    attribute
        Attribute to check calculator for.

    Raises
    ------
    NotImplementedError
        If the calculator does not provide a callable `attribute`.
    """
    # If dispersion added to MLIP calculator, use only MLIP calculator for calculation
    if isinstance(calc, SumCalculator):
        components = calc.mixer.calcs
        dispersion_wrapped = (
            len(components) == 2
            and components[1].name == "TorchDFTD3Calculator"
            and hasattr(components[0], attribute)
        )
        if dispersion_wrapped:
            setattr(calc, attribute, getattr(components[0], attribute))

    # Missing attributes resolve to None, which fails the callable check
    target = getattr(calc, attribute, None)
    if not callable(target):
        raise NotImplementedError(
            f"The attached calculator does not currently support {attribute}"
        )