"""Utility functions for CLI."""
from __future__ import annotations
from collections.abc import Sequence
from copy import deepcopy
import datetime
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typer_config import conf_callback_factory, yaml_loader
import yaml
from janus_core.helpers.utils import build_file_dir
if TYPE_CHECKING:
from ase import Atoms
from typer import Context
from janus_core.cli.types import CorrelationKwargs, TyperDict
from janus_core.helpers.janus_types import (
MaybeSequence,
PathLike,
)
[docs]
def dict_paths_to_strs(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting Path values to strings.

    The dictionary is modified in place. Nested dictionaries are processed
    recursively, including dictionaries found inside sequence values. Any
    non-string, non-bytes sequence value is replaced by a new list.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            dict_paths_to_strs(value)
        elif isinstance(value, Path):
            dictionary[key] = str(value)
        # str, bytes and bytearray are Sequences too, but must be kept whole
        # rather than being split into their individual elements.
        elif isinstance(value, Sequence) and not isinstance(
            value, (str, bytes, bytearray)
        ):
            converted = []
            for item in value:
                if isinstance(item, dict):
                    # Dictionaries inside sequences may also contain paths.
                    dict_paths_to_strs(item)
                    converted.append(item)
                elif isinstance(item, Path):
                    converted.append(str(item))
                else:
                    converted.append(item)
            # Only values of existing keys are reassigned, so mutating the
            # dictionary during iteration is safe.
            dictionary[key] = converted
[docs]
def _tuples_to_lists(value: Any) -> Any:
    """
    Return value with any tuples, at any nesting depth, replaced by lists.

    Parameters
    ----------
    value
        Value to be converted. Dictionaries are converted in place.

    Returns
    -------
    Any
        Converted value. Tuples and lists are returned as new lists.
    """
    if isinstance(value, dict):
        dict_tuples_to_lists(value)
        return value
    if isinstance(value, (tuple, list)):
        return [_tuples_to_lists(item) for item in value]
    return value


def dict_tuples_to_lists(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting tuple values to lists.

    The dictionary is modified in place. Tuples are converted wherever they
    appear: as values, nested inside other tuples or lists, or inside
    dictionaries held in lists.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for key, value in dictionary.items():
        # Only values of existing keys are reassigned, so mutating the
        # dictionary during iteration is safe.
        dictionary[key] = _tuples_to_lists(value)
[docs]
def dict_remove_hyphens(dictionary: dict) -> dict:
    """
    Recursively iterate over dictionary, replacing hyphens with underscores in keys.

    The input dictionary is left unmodified. Keys that are not strings (for
    example, integer keys loaded from YAML) are kept unchanged.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.

    Returns
    -------
    dict
        Dictionary with hyphens in keys replaced with underscores.
    """
    return {
        (key.replace("-", "_") if isinstance(key, str) else key): (
            dict_remove_hyphens(value) if isinstance(value, dict) else value
        )
        for key, value in dictionary.items()
    }
[docs]
def set_read_kwargs_index(read_kwargs: dict[str, Any]) -> None:
    """
    Set default read_kwargs["index"] to final image and check its value is an integer.

    To ensure only a single Atoms object is read, slices such as ":" are forbidden.

    Parameters
    ----------
    read_kwargs
        Keyword arguments to be passed to ase.io.read. If specified,
        read_kwargs["index"] must be an integer, and if not, a default value
        of -1 is set.

    Raises
    ------
    ValueError
        If read_kwargs["index"] cannot be interpreted as an integer.
    """
    read_kwargs.setdefault("index", -1)
    try:
        int(read_kwargs["index"])
    # int() raises TypeError, not ValueError, for values such as None or slice
    # objects, so both must be caught to report the error consistently.
    except (TypeError, ValueError) as e:
        raise ValueError("`read_kwargs['index']` must be an integer") from e
[docs]
def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
    """
    Convert list of TyperDict objects to list of dictionaries.

    The list is converted in place, and the same list object is returned.
    Falsy entries, such as None, are replaced with empty dictionaries.

    Parameters
    ----------
    typer_dicts
        List of TyperDict objects to convert.

    Returns
    -------
    list[dict]
        List of converted dictionaries.

    Raises
    ------
    ValueError
        If items in list are not converted to dicts.
    """
    for i, typer_dict in enumerate(typer_dicts):
        parsed = typer_dict.value if typer_dict else {}
        typer_dicts[i] = parsed
        if not isinstance(parsed, dict):
            # Adjacent string literals keep the message independent of how
            # this source file happens to be indented.
            raise ValueError(
                f"{parsed} must be passed as a dictionary wrapped in quotes. "
                "For example, \"{'key': value}\""
            )
    return typer_dicts
[docs]
def yaml_converter_loader(config_file: str) -> dict[str, Any]:
    """
    Load yaml configuration and replace hyphens with underscores.

    Parameters
    ----------
    config_file
        Yaml configuration file to read. If empty, no file is read.

    Returns
    -------
    dict[str, Any]
        Dictionary with loaded configuration. Empty if no configuration file
        is specified, or if the file contains no configuration.
    """
    if not config_file:
        return {}

    # NOTE(review): an empty YAML document is expected to load as None;
    # treat that the same as an empty configuration rather than failing below.
    config = yaml_loader(config_file) or {}

    # Replace all "-" with "_" in conf
    return dict_remove_hyphens(config)
# Callback built by typer_config's ``conf_callback_factory`` around
# ``yaml_converter_loader``, the YAML loader defined in this module.
# NOTE(review): presumably attached to the CLI config file option - confirm at
# the call sites.
yaml_converter_callback = conf_callback_factory(yaml_converter_loader)
[docs]
def start_summary(
    *,
    command: str,
    summary: Path,
    config: dict[str, Any],
    info: dict[str, Any],
    output_files: dict[str, PathLike],
) -> None:
    """
    Write initial summary contents.

    The summary file is created, or overwritten, as YAML containing the command,
    start time, configuration, extra information and output files. Note that
    `config`, `info` and `output_files` are modified in place.

    Parameters
    ----------
    command
        Name of CLI command being used.
    summary
        Path to summary file being saved.
    config
        Inputs to CLI command to save.
    info
        Extra information to save.
    output_files
        Output files with labels to be generated by CLI command.
    """
    # Drop the "config" entry, if present, so it is not written to the summary
    if "config" in config:
        del config["config"]

    # The summary file is itself listed as an output
    output_files["summary"] = summary.absolute()

    started = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
    contents = {
        "command": f"janus {command}",
        "start_time": started,
        "config": config,
        "info": info,
        "output_files": output_files,
    }

    # Convert paths to strings, then tuples to lists, throughout nested contents
    for convert in (dict_paths_to_strs, dict_tuples_to_lists):
        convert(contents)

    build_file_dir(summary)
    with open(summary, "w", encoding="utf8") as stream:
        yaml.dump(contents, stream, default_flow_style=False)
[docs]
def carbon_summary(*, summary: Path, log: Path) -> None:
    """
    Calculate and write carbon tracking summary.

    Emissions are summed over all log entries whose "message" is a dictionary
    containing an "emissions" value, and the total is appended to the summary.
    An empty log gives total emissions of 0.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    log
        Path to log file with carbon emissions saved.
    """
    with open(log, encoding="utf8") as file:
        # yaml.safe_load returns None for an empty document
        logs = yaml.safe_load(file) or []

    emissions = 0
    for entry in logs:
        # Skip entries that are not mappings, or have no "message", instead of failing
        message = entry.get("message") if isinstance(entry, dict) else None
        if isinstance(message, dict) and "emissions" in message:
            emissions += message["emissions"]

    with open(summary, "a", encoding="utf8") as outfile:
        yaml.dump({"emissions": emissions}, outfile, default_flow_style=False)
[docs]
def end_summary(summary: Path) -> None:
    """
    Write final time to summary and close.

    The current time is appended to the summary file as "end_time", after which
    the logging system is shut down.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    """
    with open(summary, "a", encoding="utf8") as stream:
        finished = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
        yaml.dump({"end_time": finished}, stream, default_flow_style=False)

    logging.shutdown()
[docs]
def get_struct_info(
    *,
    struct: MaybeSequence[Atoms],
    struct_path: Path,
) -> dict[str, Any]:
    """
    Add structure information to a dictionary.

    Parameters
    ----------
    struct
        Structure to be simulated.
    struct_path
        Path of structure file.

    Returns
    -------
    dict[str, Any]
        Dictionary with structure information. A single structure is described
        under "struct", and a sequence of structures under "traj", using its
        first image for the atom count and formula. Empty if `struct` is
        neither of these.
    """
    from ase import Atoms

    # Single structure
    if isinstance(struct, Atoms):
        return {
            "struct": {
                "n_atoms": len(struct),
                "struct_path": struct_path,
                "formula": struct.get_chemical_formula(),
            }
        }

    # Anything that is neither an Atoms object nor a sequence has no information
    if not isinstance(struct, Sequence):
        return {}

    # Sequence of structures, described by its length and first image
    n_images = len(struct)
    first_image = struct[0]
    return {
        "traj": {
            "length": n_images,
            "struct_path": struct_path,
            "struct": {
                "n_atoms": len(first_image),
                "formula": first_image.get_chemical_formula(),
            },
        }
    }
[docs]
def get_config(*, params: dict[str, Any], all_kwargs: dict[str, Any]) -> dict[str, Any]:
    """
    Get configuration and set kwargs dictionaries.

    Every entry of `params` whose name also appears in `all_kwargs` has its
    value replaced by the one from `all_kwargs`. `params` is updated in place.

    Parameters
    ----------
    params
        CLI input parameters from ctx.
    all_kwargs
        Name and contents of all kwargs dictionaries.

    Returns
    -------
    dict[str, Any]
        Input parameters with parsed kwargs dictionaries substituted in.
    """
    # Gather replacements first, so no new keys are ever added to params
    substitutions = {name: all_kwargs[name] for name in params if name in all_kwargs}
    params.update(substitutions)
    return params
[docs]
def check_config(ctx: Context) -> None:
    """
    Check options in configuration file are valid options for CLI command.

    Parameters
    ----------
    ctx
        Typer (Click) Context within command.

    Raises
    ------
    ValueError
        If an option in the configuration file is not a parameter of the command.
    """
    # Compare options from config file (default_map) to function definition (params)
    # default_map may be None if no defaults have been set, in which case
    # there is nothing to check.
    for option in ctx.default_map or {}:
        # Check options individually so can inform user of specific issue
        if option not in ctx.params:
            raise ValueError(f"'{option}' in configuration file is not a valid option")
[docs]
def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
    """
    Parse CLI CorrelationKwargs to md correlation_kwargs.

    Each entry of `kwargs` maps a correlation name to its options. Observables
    "a" and "b" are looked up by name in `janus_core.processing.observables`
    and constructed with "a_kwargs" and "b_kwargs" respectively. If only one
    observable is given, it is used for both, along with its kwargs unless
    kwargs for the other are given explicitly. One of "a_kwargs" or "b_kwargs"
    may be "." to repeat the other. The input `kwargs` are not modified.

    Parameters
    ----------
    kwargs
        CLI correlation keyword options.

    Returns
    -------
    list[dict]
        The parsed correlation_kwargs for md.

    Raises
    ------
    ValueError
        If unexpected options are given, neither "a" nor "b" is given,
        "points" is missing, or both "a_kwargs" and "b_kwargs" are ".".
    """
    from janus_core.processing import observables

    allowed = {
        "blocks",
        "points",
        "averaging",
        "update_frequency",
        "a_kwargs",
        "b_kwargs",
        "a",
        "b",
    }
    optional_keys = ("blocks", "averaging", "update_frequency")

    parsed_kwargs = []
    for name, cli_kwargs in kwargs.items():
        unexpected = set(cli_kwargs) - allowed
        if unexpected:
            raise ValueError(
                f"correlation_kwargs got unexpected argument(s) {unexpected}"
            )

        if "a" not in cli_kwargs and "b" not in cli_kwargs:
            raise ValueError("At least one observable must be supplied as 'a' or 'b'")

        if "points" not in cli_kwargs:
            raise ValueError("Correlation keyword argument 'points' must be specified")

        a_kwargs = cli_kwargs.get("a_kwargs", {})
        b_kwargs = cli_kwargs.get("b_kwargs", {})

        # Accept an Observable to be replicated.
        if "b" not in cli_kwargs:
            a = cli_kwargs["a"]
            b = deepcopy(a)
            # Copying Observable, so can copy kwargs as well.
            if "b_kwargs" not in cli_kwargs:
                b_kwargs = a_kwargs
        elif "a" not in cli_kwargs:
            b = cli_kwargs["b"]
            a = deepcopy(b)
            if "a_kwargs" not in cli_kwargs:
                a_kwargs = b_kwargs
        else:
            a = cli_kwargs["a"]
            b = cli_kwargs["b"]

        # Accept "." in place of one kwargs to repeat.
        if a_kwargs == "." and b_kwargs == ".":
            raise ValueError("a_kwargs and b_kwargs cannot 'ditto' each other")
        # Repeat the other kwargs even if they are empty, so "." is never
        # unpacked as keyword arguments below.
        if b_kwargs == ".":
            b_kwargs = a_kwargs
        elif a_kwargs == ".":
            a_kwargs = b_kwargs

        cor_kwargs = {
            "name": name,
            "points": cli_kwargs["points"],
            "a": getattr(observables, a)(**a_kwargs),
            "b": getattr(observables, b)(**b_kwargs),
        }

        for optional in optional_keys:
            if optional in cli_kwargs:
                cor_kwargs[optional] = cli_kwargs[optional]

        parsed_kwargs.append(cor_kwargs)
    return parsed_kwargs