Source code for janus_core.helpers.utils

"""Utility functions for janus_core."""

from __future__ import annotations

from abc import ABC
from collections.abc import Generator, Iterable, Sequence
from io import StringIO
from pathlib import Path
from typing import Any, Literal, TextIO

from ase import Atoms
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    TextColumn,
    TimeRemainingColumn,
)
from rich.style import Style

from janus_core.helpers.janus_types import (
    MaybeSequence,
    PathLike,
    SliceLike,
    StartStopStep,
)


class FileNameMixin(ABC):  # noqa: B024 (abstract-base-class-without-abstract-method)
    """
    Mixin supplying a consistent scheme for naming output files.

    Parameters
    ----------
    struct : MaybeSequence[Atoms]
        Structure(s) used to derive the default name. When a sequence is given,
        only its first structure is consulted.
    struct_path : PathLike | None
        Path to the file containing the structures.
    file_prefix : PathLike | None
        Prefix to use. If given, it is used verbatim.
    *additional
        Components appended to the derived default prefix (joined by hyphens).

    Methods
    -------
    _get_default_prefix(file_prefix, struct, struct_path, *additional)
        Return the given file_prefix, or a prefix derived from the structure file
        stem or the chemical formula of struct.
    _build_filename(suffix, *additional, filename, prefix_override)
        Return a standard format filename if filename not provided.
    """

    def __init__(
        self,
        struct: MaybeSequence[Atoms],
        struct_path: PathLike | None,
        file_prefix: PathLike | None,
        *additional,
    ) -> None:
        """
        Store the file prefix used to build default filenames.

        Parameters
        ----------
        struct : MaybeSequence[Atoms]
            Structure(s) used to derive the default name. When a sequence is
            given, only its first structure is consulted.
        struct_path : PathLike | None
            Path to the file the structures were read from. Its stem is used as
            the default prefix if not None.
        file_prefix : PathLike | None
            Prefix to use. If given, it is used verbatim.
        *additional
            Components appended to the derived default prefix (joined by
            hyphens).
        """
        default_prefix = self._get_default_prefix(
            file_prefix, struct, struct_path, *additional
        )
        self.file_prefix = Path(default_prefix)

    @staticmethod
    def _get_default_prefix(
        file_prefix: PathLike | None,
        struct: MaybeSequence[Atoms],
        struct_path: PathLike | None,
        *additional,
    ) -> str:
        """
        Work out the prefix from an explicit value, a file stem or a formula.

        Parameters
        ----------
        file_prefix : PathLike | None
            Explicit prefix. Returned unchanged (as a string) when not None, in
            which case `additional` is ignored.
        struct : MaybeSequence[Atoms]
            Structure(s) used to derive the default name. When a sequence is
            given, only its first structure is consulted.
        struct_path : PathLike | None
            Path to the file containing the structures.
        *additional
            Components appended to the derived prefix (joined by hyphens).
            Falsy components are dropped.

        Returns
        -------
        str
            File prefix.
        """
        # An explicit prefix always wins and is never decorated
        if file_prefix is not None:
            return str(file_prefix)

        if struct_path is not None:
            # The file stem is preferred over the chemical formula
            base = Path(struct_path).stem
        else:
            reference = struct[0] if isinstance(struct, Sequence) else struct
            base = reference.get_chemical_formula()

        extras = [component for component in additional if component]
        return "-".join([base, *extras])

    def _build_filename(
        self,
        suffix: str,
        *additional,
        filename: PathLike | None = None,
        prefix_override: str | None = None,
    ) -> Path:
        """
        Return `filename` if given, otherwise a name built from prefix and suffix.

        The parent directory of the returned path is created if it is missing.

        Parameters
        ----------
        suffix : str
            Final component of the default name, used when `filename` is not
            specified.
        *additional
            Extra components placed between the prefix and the suffix (joined
            with hyphens). Falsy components are dropped.
        filename : PathLike | None
            Filename to use, if specified. Default is None.
        prefix_override : str | None
            Used in place of the stored file_prefix if not None.

        Returns
        -------
        Path
            Filename specified, or default filename.
        """
        if filename:
            target = Path(filename)
        else:
            stem = self.file_prefix if prefix_override is None else prefix_override
            pieces = [str(stem), *(part for part in additional if part), suffix]
            target = Path("-".join(pieces))

        # Ensure the destination directory is present before handing the path back
        target.parent.mkdir(parents=True, exist_ok=True)
        return target
def none_to_dict(*dictionaries: Sequence[dict | None]) -> Generator[dict, None, None]:
    """
    Replace missing dictionaries with empty ones.

    Parameters
    ----------
    *dictionaries : Sequence[dict | None]
        Dictionaries to normalise, passed as separate positional arguments. Each
        one may be None.

    Yields
    ------
    dict
        The input dictionary itself when it is non-empty, otherwise a new ``{}``
        (for both `None` and empty inputs).
    """
    for candidate in dictionaries:
        # Falsy inputs (None or an empty dict) are swapped for a fresh dict
        yield candidate or {}
def write_table(
    fmt: Literal["ascii", "csv"],
    file: TextIO | None = None,
    units: dict[str, str] | None = None,
    formats: dict[str, str] | None = None,
    *,
    print_header: bool = True,
    **columns,
) -> StringIO | None:
    """
    Dump a table in a standard format.

    Table columns are passed as kwargs, with the column header being
    the keyword name and data the value.

    Each header also supports an optional "<header>_units" or
    "<header>_format" kwarg to define units and format for the column.
    These can also be passed explicitly through the respective
    dictionaries where the key is the "header".

    Parameters
    ----------
    fmt : {'ascii', 'csv'}
        Format to write table in.
    file : TextIO
        File to dump to. If unspecified function returns
        io.StringIO object simulating file.
    units : dict[str, str]
        Units as ``{key: unit}``:

        key
            To align with those in `columns`.
        unit
            String defining unit.

        Units which do not match any columns in `columns` are
        ignored. The dictionary passed in is not modified.
    formats : dict[str, str]
        Output formats as ``{key: format}``:

        key
            To align with those in `columns`.
        format
            Python format specification to use.

        The dictionary passed in is not modified.
    print_header : bool
        Whether to print the header or just append formatted data.
    **columns : dict[str, Any]
        Dictionary of columns names to data with optional
        "<header>_units"/"<header>_format" to define units/format.

        See: Examples.

    Returns
    -------
    StringIO | None
        If no file given write columns to StringIO.

    Notes
    -----
    Passing "kwarg_units" or "kwarg_format" takes priority over
    those passed in the `units`/`formats` dicts.

    Examples
    --------
    >>> data = write_table(fmt="ascii", single=(1, 2, 3),
    ...                    double=(2,4,6), double_units="THz")
    >>> print(*data, sep="", end="")
    # single | double [THz]
    1 2
    2 4
    3 6
    >>> data = write_table(fmt="csv", a=(3., 5.), b=(11., 12.),
    ...                    units={'a': 'eV', 'b': 'Ha'})
    >>> print(*data, sep="", end="")
    a [eV],b [Ha]
    3.0,11.0
    5.0,12.0
    >>> data = write_table(fmt="csv", a=(3., 5.), b=(11., 12.),
    ...                    formats={"a": ".3f"})
    >>> print(*data, sep="", end="")
    a,b
    3.000,11.0
    5.000,12.0
    >>> data = write_table(fmt="csv", a=(3., 5.), b=(11., 12.),
    ...                    a_format=".3f")
    >>> print(*data, sep="", end="")
    a,b
    3.000,11.0
    5.000,12.0
    >>> data = write_table(fmt="ascii", single=(1, 2, 3),
    ...                    double=(2,4,6), double_units="THz",
    ...                    print_header=False)
    >>> print(*data, sep="", end="")
    1 2
    2 4
    3 6
    """
    meta_suffixes = ("_units", "_format")

    # Work on copies so the caller's dictionaries are never mutated
    units = dict(units) if units else {}
    formats = dict(formats) if formats else {}

    # Per-column kwargs override entries from the explicit dictionaries
    for key, val in columns.items():
        if key.endswith("_units"):
            units[key.removesuffix("_units")] = val
        elif key.endswith("_format"):
            formats[key.removesuffix("_format")] = val

    # Only genuine data columns remain; both kinds of metadata kwargs are
    # dropped so a "<header>_format" string is not written out as a column
    columns = {
        key: val if isinstance(val, Iterable) else (val,)
        for key, val in columns.items()
        if not key.endswith(meta_suffixes)
    }

    if print_header:
        header = [
            f"{datum}" + (f" [{unit}]" if (unit := units.get(datum, "")) else "")
            for datum in columns
        ]
    else:
        header = ()

    dump_loc = file if file is not None else StringIO()

    # An empty format spec leaves the value's default formatting untouched
    write_fmt = [formats.get(key, "") for key in columns]

    if fmt == "ascii":
        _dump_ascii(dump_loc, header, columns, write_fmt)
    elif fmt == "csv":
        _dump_csv(dump_loc, header, columns, write_fmt)

    if file is None:
        # Rewind so the caller can read the table from the start
        dump_loc.seek(0)
        return dump_loc
    return None
def _dump_ascii(
    file: TextIO,
    header: list[str],
    columns: dict[str, Sequence[Any]],
    formats: Sequence[str],
) -> None:
    """
    Write data as a space-separated ASCII table.

    Parameters
    ----------
    file : TextIO
        File to dump to.
    header : list[str]
        Column name information. If empty, no header line is written.
    columns : dict[str, Sequence[Any]]
        Column data by key (ordered with header info).
    formats : Sequence[str]
        Python format specifications to apply (must align with header info).

    See Also
    --------
    write_table : Main entry function.
    """
    if header:
        title = " | ".join(header)
        print(f"# {title}", file=file)

    # Transpose the columns into rows; stops at the shortest column
    for row in zip(*columns.values()):
        cells = [format(value, spec) for value, spec in zip(row, formats)]
        print(*cells, file=file)
def _dump_csv(
    file: TextIO,
    header: list[str],
    columns: dict[str, Sequence[Any]],
    formats: Sequence[str],
) -> None:
    """
    Write data as a comma-separated table.

    Parameters
    ----------
    file : TextIO
        File to dump to.
    header : list[str]
        Column name information. If empty, no header line is written.
    columns : dict[str, Sequence[Any]]
        Column data by key (ordered with header info).
    formats : Sequence[str]
        Python format specifications to apply (must align with header info).

    See Also
    --------
    write_table : Main entry function.
    """
    separator = ","

    if header:
        print(separator.join(header), file=file)

    # Transpose the columns into rows; stops at the shortest column
    for row in zip(*columns.values()):
        cells = [format(value, spec) for value, spec in zip(row, formats)]
        print(separator.join(cells), file=file)
def track_progress(sequence: Sequence | Iterable, description: str) -> Iterable:
    """
    Iterate over a sequence while showing a progress bar in the console.

    The bar is rendered with the rich library and is advanced by rich's
    ``Progress.track`` as items are consumed from this generator.

    Parameters
    ----------
    sequence : Iterable
        The sequence to iterate over. Must support "len".
    description : str
        The text to display to the left of the progress bar.

    Yields
    ------
    Iterable
        An iterable of the values in the sequence.
    """
    # Left to right: description, bar, "M/N" counter, estimated time remaining
    columns = (
        TextColumn("{task.description}"),
        BarColumn(
            bar_width=None,
            complete_style=Style(color="#FBBB10"),
            finished_style=Style(color="#E38408"),
        ),
        MofNCompleteColumn(),
        TimeRemainingColumn(),
    )

    with Progress(*columns, expand=True, auto_refresh=False) as progress:
        yield from progress.track(sequence, description=description)
def check_files_exist(config: dict, req_file_keys: Sequence[PathLike]) -> None:
    """
    Verify that files named in a configuration dictionary are present on disk.

    Parameters
    ----------
    config : dict
        Dictionary read from configuration file.
    req_file_keys : Sequence[PathLike]
        Keys whose values must be existing files if the key is defined in the
        configuration file.

    Raises
    ------
    FileNotFoundError
        If a key from `req_file_keys` is in the configuration file, but the file
        given by the corresponding configuration value does not exist.
    """
    # Keys absent from the configuration are not checked at all
    keys_to_check = config.keys() & req_file_keys

    for key in keys_to_check:
        location = config[key]
        if Path(location).exists():
            continue
        raise FileNotFoundError(f"{location} does not exist")
def validate_slicelike(maybe_slicelike: SliceLike) -> None:
    """
    Raise an exception if `maybe_slicelike` is not a valid SliceLike.

    Accepted values are a `slice`, a `range`, an `int`, or a 3-tuple of
    ``(start, stop, step)`` where `start` and `stop` are each an `int` or `None`
    and `step` is an `int`.

    Parameters
    ----------
    maybe_slicelike : SliceLike
        Candidate to test.

    Raises
    ------
    ValueError
        If maybe_slicelike is not SliceLike.
    """

    def _int_or_none(value: Any) -> bool:
        return value is None or isinstance(value, int)

    is_valid = isinstance(maybe_slicelike, (slice, range, int))

    if (
        not is_valid
        and isinstance(maybe_slicelike, tuple)
        and len(maybe_slicelike) == 3
    ):
        start, stop, step = maybe_slicelike
        # Unlike start and stop, step may not be None
        is_valid = (
            _int_or_none(start) and _int_or_none(stop) and isinstance(step, int)
        )

    if not is_valid:
        raise ValueError(f"{maybe_slicelike} is not a valid SliceLike")
def slicelike_to_startstopstep(index: SliceLike) -> StartStopStep:
    """
    Standardize `SliceLike`s into tuple of `start`, `stop`, `step`.

    Parameters
    ----------
    index : SliceLike
        `SliceLike` to standardize.

    Returns
    -------
    StartStopStep
        Standardized `SliceLike` as `start`, `stop`, `step` triplet. `step` is
        always an integer: a `slice` created without a step is reported with a
        step of 1.

    Raises
    ------
    ValueError
        If `index` is not a valid `SliceLike`.
    """
    validate_slicelike(index)

    if isinstance(index, int):
        if index == -1:
            # ``index + 1`` would give a stop of 0 and select nothing, so the
            # last element is expressed with an open-ended stop instead
            return (index, None, 1)
        return (index, index + 1, 1)

    if isinstance(index, (slice, range)):
        # ``slice(a, b)`` stores its step as None; normalise it so the result
        # always carries an integer step
        step = 1 if index.step is None else index.step
        return (index.start, index.stop, step)

    # Already a validated (start, stop, step) tuple
    return index
def selector_len(slc: SliceLike | list, selectable_length: int) -> int:
    """
    Calculate the length of a selector applied to an indexable of a given length.

    Parameters
    ----------
    slc : Union[SliceLike, list]
        The applied SliceLike or list for selection.
    selectable_length : int
        The length of the selectable object.

    Returns
    -------
    int
        Length of the result of applying slc. An integer selects one element and
        a list selects ``len(slc)`` elements. Anything else is resolved with
        slice semantics, so `None` bounds, negative indices, negative steps and
        bounds beyond `selectable_length` are counted as slicing would.

    Raises
    ------
    ValueError
        If `slc` is neither a list nor a valid `SliceLike`, or if its step is 0.
    """
    if isinstance(slc, int):
        return 1
    if isinstance(slc, list):
        return len(slc)

    start, stop, step = slicelike_to_startstopstep(slc)

    # ``slice.indices`` resolves None and negative bounds against the length
    # and clips out-of-range values, which a bare ``range`` would not do
    resolved = slice(start, stop, step).indices(selectable_length)
    return len(range(*resolved))