"""This module holds a factory to create the |BC|."""
import logging
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Literal, Self
from lightwin.beam_calculation.beam_calculator import BeamCalculator
from lightwin.beam_calculation.cy_envelope_1d.envelope_1d import CyEnvelope1D
from lightwin.beam_calculation.envelope_1d.envelope_1d import Envelope1D
from lightwin.beam_calculation.envelope_3d.envelope_3d import Envelope3D
from lightwin.beam_calculation.tracewin.tracewin import TraceWin
from lightwin.util.typing import (
EXPORT_PHASES_T,
REFERENCE_PHASE_POLICY_T,
BeamKwargs,
)
#: Names of the beam calculators that can be requested from the factory.
BEAM_CALCULATORS = ("Envelope1D", "TraceWin", "Envelope3D")
#: Static type accepting exactly the names listed in ``BEAM_CALCULATORS``.
BEAM_CALCULATORS_T = Literal["Envelope1D", "TraceWin", "Envelope3D"]
[docs]
def _get_beam_calculator(
    tool: BEAM_CALCULATORS_T, flag_cython: bool, **kwargs
) -> type:
    """Select the |BC| class matching ``tool`` and ``flag_cython``.

    Parameters
    ----------
    tool :
        Name of the beam calculator.
    flag_cython :
        Whether the Cython flavour is requested. It must be a genuine
        ``True``/``False`` for ``"Envelope1D"`` and ``"Envelope3D"``; it is
        ignored for ``"TraceWin"``.
    kwargs :
        Accepted for call-site convenience and ignored.

    Returns
    -------
    The class to instantiate (not an instance).

    Raises
    ------
    ValueError
        If the ``tool``/``flag_cython`` combination is not handled.

    """
    # Identity checks on purpose: only the real ``True``/``False`` singletons
    # select a flavour, truthy/falsy look-alikes (``1``, ``0``...) do not.
    wants_python = flag_cython is False
    wants_cython = flag_cython is True

    if tool == "TraceWin":
        return TraceWin

    if tool == "Envelope1D":
        if wants_python:
            return Envelope1D
        if wants_cython:
            return CyEnvelope1D

    if tool == "Envelope3D":
        if wants_cython:
            logging.warning(
                "No Cython implementation for Envelope3D. Using Python implementation."
            )
        if wants_python or wants_cython:
            return Envelope3D

    raise ValueError(f"{tool = } and/or {flag_cython = } not understood.")
[docs]
class BeamCalculatorsFactory:
"""A class to create |BC| objects.
Respects singleton pattern, so that only one factory can be created.
"""
_instance: Self | None = None
def __new__(cls, *args, **kwargs) -> Self:
"""Ensure that only one instance of object exists."""
if cls._instance is None:
logging.info("Creating new BeamCalculatorsFactory instance.")
cls._instance = super().__new__(cls)
else:
logging.info("Re-using previous BeamCalculatorsFactory instance.")
return cls._instance
[docs]
@classmethod
def reset(cls) -> None:
"""Allow creation of a new factory.
Use this when the ``files`` or the ``beam`` ``TOML`` configuration
dicts were updated.
"""
cls._instance = None
[docs]
def __init__(
self, files: dict[str, Any], beam: BeamKwargs, **kwargs: dict
) -> None:
"""Set up factory with arguments common to all |BC|.
Note
----
This object was designed to work with constant ``files`` and ``beam``.
If those dictionaries happen to change during the execution of the
script, execute `BeamCalculatorsFactory.reset()`.
Note
----
This object was designed to work with constant ``files`` and ``beam``.
If those dictionaries happen to change during the execution of the
script, execute `BeamCalculatorsFactory.reset()`.
Parameters
----------
files :
Configuration entries for the input/output paths.
beam :
Configuration dictionary holding the initial beam parameters.
kwargs :
Other keyword arguments, not used for the moment.
"""
if hasattr(self, "_initialized"):
return
self._beam_kwargs = beam
self._initialized = True
# self._patch_to_remove_misunderstood_key()
self._original_dat_dir: Path = files["dat_file"].parent
self._cache: dict[int, BeamCalculator] = {}
[docs]
def _patch_to_remove_misunderstood_key(
self, beam_calculator_kw: dict[str, Any]
) -> None:
"""Patch to remove a key not understood by TraceWin. Declare id list.
.. todo::
fixme
"""
if "simulation type" in beam_calculator_kw:
del beam_calculator_kw["simulation type"]
[docs]
def run(
self,
reference_phase_policy: REFERENCE_PHASE_POLICY_T,
tool: BEAM_CALCULATORS_T,
export_phase: EXPORT_PHASES_T,
flag_cython: bool = False,
force_new: bool = False,
**beam_calculator_kw,
) -> BeamCalculator:
"""Create a single |BC|.
If a |BC| was already created with this factory and with the same
arguments, we return it instead of instantiating a new one. Unless
``force_new`` is set to ``True``.
If a :class:`.BeamCalculator` was already created with this factory
and with the same arguments, we return it instead of instantiating a
new one. Unless ``force_new`` is set to ``True``.
Parameters
----------
reference_phase_policy :
How reference phase of |CS| will be initialized.
tool :
The name of the beam calculator to construct.
export_phase :
The type of phase you want to export for your ``FIELD_MAP``.
flag_cython :
If the beam calculator involves loading cython field maps.
force_new :
To force creation of a new |BC|.
Returns
-------
An instance of the proper beam calculator.
"""
self._patch_to_remove_misunderstood_key(beam_calculator_kw)
cache_key = self._make_cache_key(
reference_phase_policy=reference_phase_policy,
tool=tool,
export_phase=export_phase,
flag_cython=flag_cython,
**beam_calculator_kw,
)
if cache_key in self._cache and not force_new:
beam_calculator = self._cache[cache_key]
logging.info(
f"Re-using existing BeamCalculator: {beam_calculator.id}"
)
return beam_calculator
beam_calculator_class = _get_beam_calculator(
tool, flag_cython=flag_cython, **beam_calculator_kw
)
beam_calculator = beam_calculator_class(
reference_phase_policy=reference_phase_policy,
default_field_map_folder=self._original_dat_dir,
beam_kwargs=self._beam_kwargs,
flag_cython=flag_cython,
export_phase=export_phase,
**beam_calculator_kw,
)
logging.info(f"Creating new BeamCalculator: {beam_calculator.id}")
self._cache[cache_key] = beam_calculator
return beam_calculator
[docs]
def _make_cache_key(
self,
reference_phase_policy: REFERENCE_PHASE_POLICY_T,
tool: BEAM_CALCULATORS_T,
export_phase: EXPORT_PHASES_T,
flag_cython: bool = False,
**beam_calculator_kw,
) -> int:
"""Create unique cache key to avoid re-creating BeamCalculators."""
key_data = (
reference_phase_policy,
tool,
export_phase,
flag_cython,
_make_hashable(beam_calculator_kw),
)
return hash(key_data)
[docs]
def run_all(
self,
beam_calculators_kw: Sequence[dict[str, Any] | None],
force_new: bool = False,
) -> tuple[BeamCalculator, ...]:
"""Create all the beam calculators."""
beam_calculators = [
self.run(force_new=force_new, **beam_calculator_kw)
for beam_calculator_kw in beam_calculators_kw
if beam_calculator_kw is not None
]
return tuple(beam_calculators)
[docs]
def _make_hashable(value: Any) -> Any:
    """Recursively convert unhashable types to hashable equivalents.

    Parameters
    ----------
    value :
        Object to convert. ``dict`` become tuples of ``(key, value)`` pairs
        in a canonical order, ``list`` and ``tuple`` become tuples, ``set``
        become ``frozenset``. The content of those containers is converted
        recursively.

    Returns
    -------
    The converted object. Any other type is returned unchanged, so the
    result is hashable only if those leaves are.

    """
    if isinstance(value, dict):
        pairs = [(key, _make_hashable(item)) for key, item in value.items()]
        try:
            # Dict keys are unique, so ordering on the key alone gives the
            # same result as ordering the full pairs.
            return tuple(sorted(pairs, key=lambda pair: pair[0]))
        except TypeError:
            # Keys of mutually unorderable types (e.g. ``int`` and ``str``):
            # fall back on an ordering that does not depend on insertion
            # order, so that equal dicts still give equal results.
            return tuple(
                sorted(
                    pairs,
                    key=lambda pair: (
                        type(pair[0]).__qualname__,
                        repr(pair[0]),
                    ),
                )
            )
    if isinstance(value, (list, tuple)):
        return tuple(_make_hashable(item) for item in value)
    if isinstance(value, set):
        return frozenset(_make_hashable(item) for item in value)
    return value