# Source code for aiida_mlip.data.model

"""Define Model Data type in AiiDA."""

import hashlib
from pathlib import Path
from typing import Any, Optional, Union
from urllib import request

from aiida.orm import QueryBuilder, SinglefileData, load_node


class ModelData(SinglefileData):
    """
    Define Model Data type in AiiDA.

    Parameters
    ----------
    file : Union[str, Path]
        Absolute path to the file.
    architecture : str
        Architecture of the mlip model.
    filename : Optional[str], optional
        Name to be used for the file (defaults to the name of provided file).

    Attributes
    ----------
    architecture : str
        Architecture of the mlip model.
    model_hash : str
        SHA-256 hash of the model file content.

    Methods
    -------
    set_file(file, filename=None, architecture=None, **kwargs)
        Set the file for the node.
    from_local(file, architecture, filename=None)
        Create a ModelData instance from a local file.
    from_uri(uri, architecture, filename="tmp_file.model", cache_dir=None,
             keep_file=False)
        Download a file from a URI and save it as ModelData.

    Other Parameters
    ----------------
    **kwargs : Any
        Additional keyword arguments.
    """

    @staticmethod
    def _calculate_hash(file: Union[str, Path]) -> str:
        """
        Calculate the hash of a file.

        Parameters
        ----------
        file : Union[str, Path]
            Path to the file for which hash needs to be calculated.

        Returns
        -------
        str
            The SHA-256 hash of the file, as a hex string.
        """
        buf_size = 65536  # read 64 kB (arbitrary) at a time
        sha256 = hashlib.sha256()
        with open(file, "rb") as f:
            # Hash in chunks rather than in one large read, so big model
            # files are never loaded into memory all at once.
            while data := f.read(buf_size):
                sha256.update(data)
        return sha256.hexdigest()

    def __init__(
        self,
        file: Union[str, Path],
        architecture: str,
        filename: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the ModelData object.

        Parameters
        ----------
        file : Union[str, Path]
            Absolute path to the file.
        architecture : str
            Architecture of the mlip model.
        filename : Optional[str], optional
            Name to be used for the file (defaults to the name of provided file).

        Other Parameters
        ----------------
        **kwargs : Any
            Additional keyword arguments.
        """
        # The parent constructor is expected to route the file through
        # ``set_file`` (which also records ``model_hash``); the architecture
        # is then stored explicitly here.
        super().__init__(file, filename, **kwargs)
        self.base.attributes.set("architecture", architecture)

    def set_file(
        self,
        file: Union[str, Path],
        filename: Optional[str] = None,
        architecture: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Set the file for the node.

        Parameters
        ----------
        file : Union[str, Path]
            Absolute path to the file.
        filename : Optional[str], optional
            Name to be used for the file (defaults to the name of provided file).
        architecture : Optional[str], optional
            Architecture of the mlip model. If None, any architecture already
            stored on the node is left unchanged.

        Other Parameters
        ----------------
        **kwargs : Any
            Additional keyword arguments.
        """
        super().set_file(file, filename, **kwargs)
        # Only overwrite the architecture when one is given, so replacing the
        # file does not silently discard previously stored metadata.
        if architecture is not None:
            self.base.attributes.set("architecture", architecture)
        # The hash must always track the content of the current file.
        self.base.attributes.set("model_hash", self._calculate_hash(file))

    @classmethod
    def from_local(
        cls,
        file: Union[str, Path],
        architecture: str,
        filename: Optional[str] = None,
    ) -> "ModelData":
        """
        Create a ModelData instance from a local file.

        Parameters
        ----------
        file : Union[str, Path]
            Path to the file.
        architecture : str
            Architecture of the mlip model.
        filename : Optional[str], optional
            Name to be used for the file (defaults to the name of provided file).

        Returns
        -------
        ModelData
            A ModelData instance.
        """
        file_path = Path(file).resolve()
        return cls(file=file_path, architecture=architecture, filename=filename)

    @classmethod
    def from_uri(
        cls,
        uri: str,
        architecture: str,
        filename: Optional[str] = "tmp_file.model",
        cache_dir: Optional[Union[str, Path]] = None,
        keep_file: Optional[bool] = False,
    ) -> "ModelData":
        """
        Download a file from a URI and save it as ModelData.

        Parameters
        ----------
        uri : str
            URI of the file to download.
        architecture : str
            Architecture of the mlip model.
        filename : Optional[str], optional
            Name to be used for the downloaded file (defaults to
            "tmp_file.model"; None or "" also fall back to this name).
        cache_dir : Optional[Union[str, Path]], optional
            Path to the folder where the file has to be saved (defaults to
            "~/.cache/mlips/").
        keep_file : Optional[bool], optional
            True to keep the downloaded model, even if there are duplicates.
            (default: False, the file is deleted and only saved in the
            database).

        Returns
        -------
        ModelData
            A ModelData instance. When ``keep_file`` is False and a node of
            this class with the same hash and architecture already exists,
            the oldest such node is returned instead of the new one.
        """
        cache_dir = (
            Path(cache_dir) if cache_dir else Path("~/.cache/mlips/").expanduser()
        )
        arch_dir = (cache_dir / architecture) if architecture else cache_dir
        arch_path = arch_dir.resolve()
        arch_path.mkdir(parents=True, exist_ok=True)
        # The signature allows filename=None; ``Path / None`` would raise, and
        # "" would point the download at the directory itself.
        file = arch_path / (filename or "tmp_file.model")

        # Download file
        request.urlretrieve(uri, file)

        model = cls.from_local(file=file, architecture=architecture)

        if keep_file:
            return model

        # The node keeps its own copy of the content, so the download can go.
        file.unlink(missing_ok=True)

        # Check if the same model was used previously. Query ``cls`` so that a
        # subclass never gets back a node of a different (base) type.
        qb = QueryBuilder()
        qb.append(
            cls,
            tag="model",
            filters={
                "attributes.model_hash": model.model_hash,
                "attributes.architecture": model.architecture,
                "ctime": {"!in": [model.ctime]},
            },
            project=["pk"],
        )
        # Deterministically pick the earliest matching node.
        qb.order_by({"model": {"ctime": "asc"}})
        previous = qb.first()  # single query instead of count() + first()
        if previous is not None:
            model = load_node(previous[0])
        return model

    @property
    def architecture(self) -> str:
        """
        Return the architecture.

        Returns
        -------
        str
            Architecture of the mlip model.
        """
        return self.base.attributes.get("architecture")

    @property
    def model_hash(self) -> str:
        """
        Return the hash of the model file.

        Returns
        -------
        str
            SHA-256 hash of the MLIP model file content.
        """
        return self.base.attributes.get("model_hash")