Source code for lightwin.config.toml_formatter
"""Define several helper functions for proper ``TOML`` formatting."""
import logging
from typing import Any
import numpy as np
[docs]
def format_for_toml(key: str, value: Any, preferred_type: type) -> str:
    """Format the key-value pair so that it matches ``toml`` standard.

    Parameters
    ----------
    key :
        Name of the entry, written as-is on the left of the ``=`` sign.
    value :
        Value to format. Dicts are written as inline tables; lists, tuples
        and numpy arrays are written as arrays; any other value is handled
        by :func:`_format_value`.
    preferred_type :
        Type the value should be written as (for instance ``str`` to force
        quotation marks, ``bool`` to force ``true``/``false``). Ignored for
        dict values.

    Returns
    -------
    str
        A single ``key = value`` line.

    """
    if isinstance(value, dict):
        formatted_value = _concat_dict(value)
    elif isinstance(value, (list, tuple, np.ndarray)):
        # Tuples must be written as arrays too: formatting them as scalars
        # would give ``(1, 2)``, which is not valid toml.
        formatted_value = _format_list(value, preferred_type)
    else:
        formatted_value = _format_value(key, value, preferred_type)
    return f"{key} = {formatted_value}"
[docs]
def _concat_dict(value: dict[str, Any]) -> str:
    """Convert a Python dict into a ``toml`` inline table.

    Every entry is rendered by :func:`format_for_toml`, using the type of
    the entry's own value as the preferred type.

    Parameters
    ----------
    value :
        Mapping to convert.

    Returns
    -------
    str
        The inline table, e.g. ``{ a = 1, b = "x" }``.

    """
    pairs = [
        format_for_toml(sub_key, sub_value, type(sub_value))
        for sub_key, sub_value in value.items()
    ]
    joined = ", ".join(pairs)
    return f"{{ {joined} }}"
[docs]
def _format_list(value: list | np.ndarray, preferred_type: type) -> str:
"""Format a list of values, including handling lists of dicts."""
if all(isinstance(item, dict) for item in value):
formatted_items = [_concat_dict(item) for item in value]
else:
if isinstance(value, np.ndarray):
value = value.tolist()
formatted_items = [str(item) for item in value]
return "[ " + ", ".join(formatted_items) + " ]"
[docs]
def _format_value(key: str, value: Any, preferred_type: type) -> str:
    """Format a scalar value so that it matches ``toml`` standard.

    Parameters
    ----------
    key :
        Name of the entry; forwarded to the helpers for their error
        messages.
    value :
        Scalar value to format.
    preferred_type :
        Type the value should be written as. ``str`` surrounds the value
        with quotation marks; ``bool`` writes ``true`` / ``false``.

    Returns
    -------
    str
        The formatted value. Values that are neither forced to ``str`` nor
        booleans are written with plain ``str`` formatting.

    """
    if preferred_type is str:
        return _str_toml(key, value)
    if preferred_type is bool or isinstance(value, (bool, np.bool_)):
        # Python renders booleans as ``True``/``False``, which ``toml``
        # rejects: actual booleans always go through the lowercase
        # conversion, whatever ``preferred_type`` says.
        return _bool_toml(key, value)
    return f"{value}"
[docs]
# Translation table for ``toml`` basic strings: the quotation mark, the
# backslash and the control characters U+0000-U+001F and U+007F must be
# escaped. The usual short escapes are used where ``toml`` defines one;
# other control characters use the ``\uXXXX`` form.
_TOML_BASIC_STRING_ESCAPES = {
    code: f"\\u{code:04X}" for code in (*range(0x20), 0x7F)
}
_TOML_BASIC_STRING_ESCAPES.update(
    {
        ord("\b"): "\\b",
        ord("\t"): "\\t",
        ord("\n"): "\\n",
        ord("\f"): "\\f",
        ord("\r"): "\\r",
        ord('"'): '\\"',
        ord("\\"): "\\\\",
    }
)


def _str_toml(key: str, value: Any) -> str:
    """Surround value with quotation marks, escaping it as ``toml`` requires.

    Parameters
    ----------
    key :
        Name of the entry; only used in the error message.
    value :
        Value to write. Non-string values are converted with :func:`str`.

    Returns
    -------
    str
        A ``toml`` basic string, e.g. ``"C:\\\\Users"`` for ``C:\\Users``.

    Raises
    ------
    TypeError
        If ``str(value)`` raises a :class:`TypeError`.

    """
    if not isinstance(value, str):
        try:
            value = str(value)
        except TypeError as err:
            msg = (
                f"You gave to {key = } the {value = }, which is not "
                "broadcastable to a string."
            )
            logging.error(msg)
            raise TypeError(msg) from err
    # Without escaping, backslashes (Windows paths) would be read as toml
    # escape sequences and quotes would terminate the string early.
    return '"' + value.translate(_TOML_BASIC_STRING_ESCAPES) + '"'
[docs]
def _bool_toml(key: str, value: Any) -> str:
    """Return the ``toml`` boolean ``'true'`` or ``'false'``.

    Parameters
    ----------
    key :
        Name of the entry; only used in the error message.
    value :
        Value to convert. Booleans are used as-is. The strings ``"true"``
        and ``"false"`` (case-insensitive, surrounding whitespace ignored)
        are parsed explicitly. Anything else goes through :func:`bool`.

    Returns
    -------
    str
        ``"true"`` or ``"false"``.

    Raises
    ------
    TypeError
        If ``bool(value)`` raises a :class:`TypeError`.

    """
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "false"):
            # ``bool("false")`` is ``True``: textual booleans must be parsed,
            # not evaluated for truthiness.
            value = text == "true"
    if not isinstance(value, bool):
        try:
            value = bool(value)
        except TypeError as err:
            msg = (
                f"You gave to {key = } the {value = }, which is not "
                "broadcastable to a bool."
            )
            logging.error(msg)
            raise TypeError(msg) from err
    return "true" if value else "false"