Source code for lightwin.config.config_manager

"""Load, validate and post-process the configuration.

.. todo::
    Add a method to generate the documentation ``.csv`` files from the
    tables, i.e. ``docs/manual/configuration_entries``.

"""

import logging
import shutil
import tomllib
from importlib import resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Any

from lightwin.config.full_specs import ConfSpec


class ConfigFileNotFoundError(FileNotFoundError):
    """Signal that the requested configuration file does not exist.

    It derives from :class:`FileNotFoundError`, so handlers written for the
    builtin exception catch it as well.

    """
class InvalidTomlSyntaxError(ValueError):
    """Signal that the configuration file is not syntactically valid ``TOML``.

    It derives from :class:`ValueError`, so handlers written for the builtin
    exception catch it as well.

    """
def process_config(
    toml_path: Path | str | Traversable,
    config_keys: dict[str, str],
    warn_mismatch: bool = False,
    override: dict[str, dict[str, Any]] | None = None,
    conf_specs_t: type[ConfSpec] = ConfSpec,
) -> dict[str, dict[str, Any]]:
    """Read the configuration file, extract the wanted tables, prepare them.

    Parameters
    ----------
    toml_path :
        Where the configuration lives: a path to a real file (``str`` or
        :class:`~pathlib.Path`) or a resource reference.
    config_keys :
        Maps the name of each LightWin group of parameters to the name of the
        matching table in the configuration file. It is also unpacked as
        keyword arguments to instantiate ``conf_specs_t``.
    warn_mismatch :
        Log a warning when a key of an ``override`` sub-dict is absent from
        the table it overrides.
    override :
        Entries replacing those read from the ``TOML``. Nothing is overridden
        when it is not provided.
    conf_specs_t :
        Class of the specifications that the ``TOML`` must match to be
        accepted. Its ``prepare`` method is called with the extracted tables
        and the folder holding the ``TOML``.

    Returns
    -------
        A dictionary holding all the keyword arguments that will be passed to
        LightWin objects, eg ``beam_calculator`` will be passed to
        :class:`.BeamCalculator`.

    """
    tables = _process_toml(
        _load_toml(toml_path),
        config_keys,
        warn_mismatch=warn_mismatch,
        override=override,
    )
    specs = conf_specs_t(**config_keys)

    if isinstance(toml_path, str):
        specs.prepare(tables, toml_folder=Path(toml_path).parent)
    elif isinstance(toml_path, Path):
        specs.prepare(tables, toml_folder=toml_path.parent)
    else:
        # A resource may not live on the filesystem: ``as_file`` provides a
        # real path for as long as the context is open.
        with resources.as_file(toml_path) as extracted_path:
            specs.prepare(tables, toml_folder=extracted_path.parent)
    return tables
[docs] def _load_toml( toml_path: Path | str | Traversable, ) -> dict[str, dict[str, Any]]: """Load the ``TOML`` and extract the dicts asked by user. Parameters ---------- toml_path : Path to the configuration file. It can be path to a real file or a resource reference. Returns ------- Dictionary holding the whole ``TOML`` file. """ if isinstance(toml_path, (str, Path)): toml_path = Path(toml_path) if not toml_path.is_file(): raise ConfigFileNotFoundError( f"The file {toml_path} does not exist." ) try: with open(toml_path, "rb") as f: raw_toml = tomllib.load(f) except tomllib.TOMLDecodeError as e: raise InvalidTomlSyntaxError( f"Invalid TOML syntax in file {toml_path}: {e}" ) return raw_toml if isinstance(toml_path, Traversable): try: with toml_path.open("rb") as f: raw_toml = tomllib.load(f) except tomllib.TOMLDecodeError as e: raise InvalidTomlSyntaxError( f"Invalid TOML syntax in resource {toml_path}: {e}" ) return raw_toml raise TypeError( f"Unsupported type for `config_path`: {type(toml_path)}. Expected " "str, Path, or Traversable." )
[docs] def _process_toml( raw_toml: dict[str, dict[str, Any]], config_keys: dict[str, str], *, warn_mismatch: bool, override: dict[str, dict[str, Any]] | None, ) -> dict[str, dict[str, Any]]: """Extract the dicts asked by user. Override some keys if requested. Parameters ---------- raw_toml : Dictionary holding the whole ``TOML`` file. config_keys : Keys will be the keys of the output. Values are the name of the tables in the configuration file. If ``config_keys = {"beam": "proton_beam"}`` , the output will look like ``{"beam": {<content of [proton_beam]>}}``. warn_mismatch : Check if there are discrepancies between ``override`` and the keys or dicts to override in ``config_keys``. override : To override some entries of the output dictionary, before even testing it. Returns ------- A dictionary which keys are the keys of ``config_keys``, and the values are dictionaries holding corresponding table entries from the configuration file. """ toml_fulldict = {} for key, value in config_keys.items(): if value not in raw_toml: raise KeyError( f"Expected table '{value}' for key '{key}' not found in the " "TOML file." ) toml_fulldict[key] = raw_toml[value] if override: _override_some_toml_entries(toml_fulldict, warn_mismatch, **override) return toml_fulldict
[docs] def _override_some_toml_entries( toml_fulldict: dict[str, dict[str, Any]], warn_mismatch: bool, **override: dict[str, Any], ) -> None: """Override some entries before testing.""" for over_key, over_subdict in override.items(): assert over_key in toml_fulldict, ( f"You want to override entries in {over_key = }, which was not " f"found in {toml_fulldict.keys() = }" ) conf_subdict = toml_fulldict[over_key] for key, val in over_subdict.items(): if warn_mismatch and key not in conf_subdict: logging.warning( f"You want to override {key = }, which was not found in " f"{conf_subdict.keys() = }. Setting it anyway..." ) conf_subdict[key] = val
def dict_to_toml(
    toml_fulldict: dict[str, dict[str, Any]],
    toml_path: Path,
    conf_specs: ConfSpec,
    allow_overwrite: bool = False,
    original_toml_folder: Path | None = None,
    **kwargs,
) -> None:
    """Write the provided configuration dict to a ``TOML`` file.

    Parameters
    ----------
    toml_fulldict :
        The configuration as a nested dictionary. The keys will be used as
        table entries.
    toml_path :
        Where to save the ``TOML``.
    conf_specs :
        Holds the template to be respected. In particular, the type of the
        values in the different tables.
    allow_overwrite :
        If a pre-existing ``TOML`` can be overwritten. The default is False,
        in which case an error is logged and nothing is written.
    original_toml_folder :
        Where the original ``TOML`` was; this is used to resolve paths
        relative to this location.
    **kwargs :
        Forwarded as is to ``conf_specs.to_toml_strings``.

    """
    if _indue_overwritting(toml_path, allow_overwrite):
        return

    strings = conf_specs.to_toml_strings(
        toml_fulldict, original_toml_folder=original_toml_folder, **kwargs
    )

    # A TOML document must be UTF-8 encoded, and ``tomllib`` decodes it as
    # such when reading back: do not rely on the locale's default encoding.
    with open(toml_path, "w", encoding="utf-8") as f:
        for dict_entry_string in strings:
            f.write(dict_entry_string)
            f.write("\n")
    logging.info(f"New ``TOML`` written in {toml_path}")
[docs] def _indue_overwritting( toml_path: Path, allow_overwrite: bool = False, ) -> bool: """Ensure that ``TOML`` will not be overwritten if not wanted.""" if not toml_path.exists(): return False logging.info( f"A .toml already exists at {toml_path = } and may be overwritten." ) if not allow_overwrite: logging.error("Overwritting not permitted. Skipping action...") return True old = toml_path.with_suffix(".toml.old") logging.info(f"Copying the old one to {old}, just in case...") shutil.copy(toml_path, old) return False