Source code for janus_core.cli.train

"""Set up MLIP training commandline interface."""

from __future__ import annotations

from pathlib import Path
from typing import Annotated

from typer import Option, Typer
import yaml

from janus_core.cli.types import Architectures, LogPath, Summary

app = Typer()


[docs] @app.command() def train( arch: Architectures, mlip_config: Annotated[ Path, Option(help="Configuration file to pass to MLIP CLI.", show_default=False) ], fine_tune: Annotated[ bool, Option(help="Whether to fine-tune a foundational model.") ] = False, file_prefix: Annotated[ Path | None, Option( help=( "Prefix for output files, including directories." "Default directory is ./janus_results." ) ), ] = None, log: LogPath | None = None, tracker: bool = True, summary: Summary | None = None, ) -> None: """ Run training for MLIP by passing a configuration file to the MLIP's CLI. Parameters ---------- arch The achitecture to train with. mlip_config Configuration file to pass to MLIP CLI. fine_tune Whether to fine-tune a foundational model. Default is False. file_prefix Prefix for output files, including directories. Default directory is ./janus_results. log Path to write logs to. Default is inferred from `file_prefix`. tracker Whether to save carbon emissions of calculation in log file and summary. Default is True. summary Path to save summary of inputs, start/end time, and carbon emissions. Default is inferred from `file_prefix`. """ from janus_core.cli.utils import carbon_summary, end_summary, start_summary from janus_core.training.train import train as run_train with open(mlip_config, encoding="utf8") as config_file: config = yaml.safe_load(config_file) match arch: case "mace" | "mace_mp" | "mace_off" | "mace_omol": if fine_tune: if "foundation_model" not in config: raise ValueError( "Please include `foundation_model` in your configuration file" ) if ( config["foundation_model"] not in ( "small", "medium", "large", "small_off", "medium_off", "large_off", ) and not Path(config["foundation_model"]).exists() ): raise ValueError( """ Invalid foundational model. Valid options are: 'small', 'medium','large', 'small_off', 'medium_off', 'large_off', or a path to the model """ ) elif "foundation_model" in config: raise ValueError( "Please include the `--fine-tune` option for fine-tuning" ) case "nequip": if "training_module" not in config: raise ValueError( """There is no top-level training_module section in your configuration file.""" ) if "model" not in config["training_module"]: raise ValueError( """There is no model section in the training_module section of your configuration file.""" ) if "_target_" not in config["training_module"]["model"]: raise ValueError( """There is _target_ section in the model section of your configuration file.""" ) model = config["training_module"]["model"]["_target_"] # See nequip.model.__all__ if fine_tune and model not in ( "nequip.model.ModelFromCheckpoint", "nequip.model.ModelFromPackage", ): raise ValueError( """Fine-tuning requested but there is no checkpoint or package specified in your config.""" ) case _: raise ValueError(f"Unsupported Architecture ({arch})") if file_prefix is None: file_prefix = Path.cwd() / "janus_results" if log is None: log = file_prefix / "train-log.yml" if summary is None: summary = file_prefix / "train-summary.yml" config = { "mlip_config": mlip_config, "fine_tune": fine_tune, "log": log, "tracker": tracker, "summary": summary, } log_kwargs = {"filemode": "w"} if log: log_kwargs["filename"] = log output_files = {"log": log.absolute()} # Save summary information before training begins start_summary( command="train", summary=summary, info={}, config=config, output_files=output_files, ) # Run training run_train( arch, mlip_config, file_prefix, attach_logger=True, log_kwargs=log_kwargs, track_carbon=tracker, ) # Save carbon summary if tracker: carbon_summary(summary=summary, log=log) # Save time after training has finished end_summary(summary)