Source code for lightwin.util.helper

"""Define misc helper functions.

.. todo::
    Clean this, check what is still used.

"""

import logging
import re
from collections.abc import Generator, Iterable
from typing import Any, Iterator

import numpy as np
import pandas as pd


# =============================================================================
# For getter and setters
# =============================================================================
[docs] def recursive_items(dictionary: dict[Any, Any]) -> Iterator[str]: """Recursively list all keys of a possibly nested dictionary.""" for key, value in dictionary.items(): if isinstance(value, dict): yield key yield from recursive_items(value) elif hasattr(value, "has"): yield key yield from recursive_items(vars(value)) # for ListOfElements: if isinstance(value, list): yield from recursive_items(vars(value[0])) else: yield key
[docs] def recursive_getter( wanted_key: str, dictionary: dict[str, Any], **kwargs: Any ) -> Any: """Get first key in a possibly nested dictionary.""" if wanted_key in dictionary: return dictionary[wanted_key] for value in dictionary.values(): if isinstance(value, dict): corresp_value = recursive_getter(wanted_key, value, **kwargs) if corresp_value is not None: return corresp_value elif hasattr(value, "get"): corresp_value = value.get(wanted_key, **kwargs) if corresp_value is not None: return corresp_value return None
# ============================================================================= # For lists manipulations # =============================================================================
[docs] def flatten[T](nest: Iterable[T]) -> Iterator[T]: """Flatten nested list of lists of...""" for _in in nest: if isinstance(_in, Iterable) and not isinstance(_in, (str, bytes)): yield from flatten(_in) else: yield _in
[docs] def chunks[T](lst: list[T], n_size: int) -> Generator[list[T], int, None]: """Yield successive ``n_size``-ed chunks from ``lst``. https://stackoverflow.com/questions/312443/how-do-i-split-a-list-into-equal ly-sized-chunks """ for i in range(0, len(lst), n_size): yield lst[i : i + n_size]
[docs] def remove_duplicates[T](iterable: Iterable[T]) -> Iterator[T]: """Create an iterator without duplicates. Taken from: https://stackoverflow.com/questions/32012878/iterator-object-for-removing-duplicates-in-python """ seen = set() for item in iterable: if item in seen: continue seen.add(item) yield item
# ============================================================================= # Messages functions # =============================================================================
[docs] def pd_output(df: pd.DataFrame, header: str = "") -> str: """Return a formatted string representation of a pandas DataFrame.""" width = 100 sep = "=" * width sub_sep = "-" * width header_line = f"{header}\n{sub_sep}\n" if header else "" return f"\n{sep}\n{header_line}{df.to_string()}\n{sep}"
[docs] def pascal_case(message: str) -> str: """Convert a string to Pascal case (as class names). Examples -------- >>> pascal_case("bonjoure sa_vA") "BonjoureSaVa" >>> pascal_case("BonjoureSaVa") "BonjoureSaVa" """ parts = re.split(r"[ _]+", message) return "".join(p[:1].upper() + p[1:] for p in parts if p)
[docs] def get_constructor(name: str, constructors: dict[str, type]) -> type: """Get the proper class from a string and dict of classes.""" pascal_name = pascal_case(name) if pascal_name in constructors: return constructors[pascal_name] if name in constructors: constructor = constructors[name] logging.warning( f"{constructor = } matches the provided {name = }, but consider " f"calling it {pascal_name} for consistency." ) return constructor raise KeyError( f"Neither {pascal_name = } nor {name = } is in {constructors = }" )
[docs] def get_constructors( names: Iterable[str], constructors: dict[str, type] ) -> Generator[type, None, None]: """Get several class constructors from their names.""" return (get_constructor(name, constructors) for name in names)
# TODO: replace nan by ' ' when there is a \n in a pd DataFrame header # def printd(message: str, header: str = '') -> None: # """Print delimited message.""" # pd.options.display.float_format = '{:.6f}'.format # pd.options.display.max_columns = 10 # pd.options.display.max_colwidth = 18 # pd.options.display.width = 250 # # tot = 100 # # my_output = header + "\n" + "-" * tot + "\n" + message.to_string() # # my_output += "\n" + "-" * tot # my_output = pd_output(message, header) # logging.info(my_output)
[docs] def resample( x_1: np.ndarray, y_1: np.ndarray, x_2: np.ndarray, y_2: np.ndarray ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Downsample y_highres(olution) to x_1 or x_2 (the one with low res).""" assert x_1.shape == y_1.shape assert x_2.shape == y_2.shape if x_1.shape > x_2.shape: y_1 = np.interp(x_2, x_1, y_1) x_1 = x_2 return x_1, y_1, x_2, y_2 y_2 = np.interp(x_1, x_2, y_2) x_2 = x_1 return x_1, y_1, x_2, y_2
[docs] def range_vals(name: str, data: np.ndarray | None) -> str: """Return formatted first and last value of the ``data`` array.""" out = f"{name:17s}" if data is None: return out + " (not set)\n" out += f"{data[0]:+9.5e} -> {data[-1]:+9.5e} | {data.shape}\n" return out
[docs] def range_vals_object(obj: object, name: str) -> str: """Return first and last value of the ``name`` attr from ``obj``.""" val = getattr(obj, name) out = f"{name:17s}" if val is None: return out + " (not set)\n" if isinstance(val, float): return out + f"{val} (single value)\n" out += f"{val[0]:+9.5e} -> {val[-1]:+9.5e} | {val.shape}\n" return out
# ============================================================================= # Files functions # =============================================================================
[docs] def save_energy_phase_tm(lin: object) -> None: """Save energy, phase, transfer matrix as a function of s. s [m] E[MeV] phi[rad] M_11 M_12 M_21 M_22 Parameters ---------- lin : Object of corresponding to desired output. """ n_z = lin.get("z_abs").shape[0] data = np.column_stack( ( lin.get("z_abs"), lin.get("w_kin"), lin.get("phi_abs_array"), np.reshape(lin.transf_mat["tm_cumul"], (n_z, 4)), ) ) filepath = lin.files["results_folder"] + lin.name + "_energy_phase_tm.txt" filepath = filepath.replace(" ", "_") header = ( "s [m] \t W_kin [MeV] \t phi_abs [rad]" + "\t M_11 \t M_12 \t M_21 \t M_22" ) np.savetxt(filepath, data, header=header) logging.info(f"Energy, phase and TM saved in {filepath}")