"""This module holds a factory to create the :class:`.BeamCalculator`."""
import logging
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Literal, get_args

from lightwin.beam_calculation.beam_calculator import BeamCalculator
from lightwin.beam_calculation.cy_envelope_1d.envelope_1d import CyEnvelope1D
from lightwin.beam_calculation.envelope_1d.envelope_1d import Envelope1D
from lightwin.beam_calculation.envelope_3d.envelope_3d import Envelope3D
from lightwin.beam_calculation.tracewin.tracewin import TraceWin
#: Type hint for the name of a :class:`.BeamCalculator` tool.
BEAM_CALCULATORS_T = Literal["Envelope1D", "TraceWin", "Envelope3D"]
#: Names of the available :class:`.BeamCalculator` tools. Derived from
#: :data:`BEAM_CALCULATORS_T` so that the two lists can never diverge.
BEAM_CALCULATORS = get_args(BEAM_CALCULATORS_T)
def _get_beam_calculator(
    tool: BEAM_CALCULATORS_T, flag_cython: bool = False, **kwargs
) -> type:
    """Select the :class:`.BeamCalculator` subclass matching ``tool``.

    Parameters
    ----------
    tool : BEAM_CALCULATORS_T
        Name of the requested solver.
    flag_cython : bool, optional
        Whether the Cython implementation should be used. Only the exact
        singletons ``True`` and ``False`` are accepted. It is ignored for
        ``"TraceWin"``.
    kwargs :
        Other configuration entries; accepted but unused here.

    Returns
    -------
    type
        The class (not an instance) to instantiate.

    Raises
    ------
    ValueError
        If the ``(tool, flag_cython)`` combination is not handled.

    """
    # TraceWin has a single implementation, whatever ``flag_cython`` holds.
    if tool == "TraceWin":
        return TraceWin

    # Identity checks on purpose: only real booleans select an
    # implementation; values such as ``1`` or ``None`` reach the ValueError.
    if flag_cython is False:
        if tool == "Envelope1D":
            return Envelope1D
        if tool == "Envelope3D":
            return Envelope3D
    elif flag_cython is True:
        if tool == "Envelope1D":
            return CyEnvelope1D
        if tool == "Envelope3D":
            # No compiled version exists: fall back to pure Python.
            logging.warning(
                "No Cython implementation for Envelope3D. Using Python implementation."
            )
            return Envelope3D

    raise ValueError(f"{tool = } and/or {flag_cython = } not understood.")
class BeamCalculatorsFactory:
    """A class to create :class:`.BeamCalculator` objects.

    It stores the configuration of one solver (used during optimisation) and
    of an optional second one (used afterwards). It instantiates them with the
    arguments they all share.

    """

    def __init__(
        self,
        beam_calculator: dict[str, Any],
        files: dict[str, Any],
        beam: dict[str, Any],
        beam_calculator_post: dict[str, Any] | None = None,
        **other_kw: dict,
    ) -> None:
        """Set up factory with arguments common to all :class:`.BeamCalculator`.

        Parameters
        ----------
        beam_calculator : dict[str, Any]
            Configuration entries for the first :class:`.BeamCalculator`, used
            for optimisation. Must hold a ``"tool"`` key.
        files : dict[str, Any]
            Configuration entries for the input/output paths. Its
            ``"dat_file"`` entry is expected to be a :class:`pathlib.Path`;
            its parent folder becomes the default field map folder.
        beam : dict[str, Any]
            Configuration dictionary holding the initial beam parameters.
        beam_calculator_post : dict[str, Any] | None, optional
            Configuration entries for the second optional
            :class:`.BeamCalculator`, used for a more thorough calculation of
            the beam propagation once the compensation settings are found.
            Must hold a ``"tool"`` key when given.
        other_kw : dict
            Other keyword arguments, not used for the moment.

        """
        self.all_beam_calculator_kw: tuple[dict[str, Any], ...] = (
            (beam_calculator,)
            if beam_calculator_post is None
            else (beam_calculator, beam_calculator_post)
        )
        self._beam_kwargs = beam
        self.out_folders = self._set_out_folders(self.all_beam_calculator_kw)
        self.beam_calculators_id: list[str] = []
        # The stored dicts are the caller's own objects: the patch below
        # removes a key from them in place.
        self._patch_to_remove_misunderstood_key()
        self._original_dat_dir: Path = files["dat_file"].parent

    def _set_out_folders(
        self,
        all_beam_calculator_kw: Sequence[dict[str, Any]],
    ) -> list[Path]:
        """Set in which subfolder the results will be saved.

        The ``i``-th configuration gets the relative folder ``{i}_{tool}``.

        """
        return [
            Path(f"{i}_{kw['tool']}")
            for i, kw in enumerate(all_beam_calculator_kw)
        ]

    def _patch_to_remove_misunderstood_key(self) -> None:
        """Remove the ``"simulation type"`` key, not understood by TraceWin.

        The configuration dictionaries are modified in place.

        .. todo::
            fixme

        """
        for beam_calculator_kw in self.all_beam_calculator_kw:
            beam_calculator_kw.pop("simulation type", None)

    def run(
        self, tool: BEAM_CALCULATORS_T, **beam_calculator_kw
    ) -> BeamCalculator:
        """Create a single :class:`.BeamCalculator`.

        Parameters
        ----------
        tool : Literal["Envelope1D", "TraceWin", "Envelope3D"]
            The name of the beam calculator to construct.
        beam_calculator_kw :
            Other configuration entries. They are used to select the
            implementation (e.g. ``flag_cython``) and are forwarded to the
            constructor.

        Returns
        -------
        BeamCalculator
            An instance of the proper beam calculator.

        """
        beam_calculator_class = _get_beam_calculator(
            tool, **beam_calculator_kw
        )
        beam_calculator = beam_calculator_class(
            out_folder=self._next_out_folder(tool),
            default_field_map_folder=self._original_dat_dir,
            beam_kwargs=self._beam_kwargs,
            **beam_calculator_kw,
        )
        self.beam_calculators_id.append(beam_calculator.id)
        return beam_calculator

    def _next_out_folder(self, tool: str) -> Path:
        """Give the output subfolder of the next :class:`.BeamCalculator`.

        Folders pre-computed from the configuration are consumed in order.
        When they are exhausted (:meth:`run` called more times than there are
        configured solvers), a ``{i}_{tool}`` folder is derived from the
        number of solvers already created instead of raising IndexError.

        """
        if self.out_folders:
            return self.out_folders.pop(0)
        return Path(f"{len(self.beam_calculators_id)}_{tool}")

    def run_all(self) -> tuple[BeamCalculator, ...]:
        """Create all the configured beam calculators, in configuration order."""
        beam_calculators = tuple(
            self.run(**beam_calculator_kw)
            for beam_calculator_kw in self.all_beam_calculator_kw
        )
        self._check_consistency_absolute_phases(beam_calculators)
        return beam_calculators

    def _check_consistency_absolute_phases(
        self, beam_calculators: Sequence[BeamCalculator]
    ) -> None:
        """Check that ``flag_phi_abs`` is the same for all solvers.

        A warning is logged when values differ; nothing is raised.

        """
        # Keyed by ``id`` so that the log is readable and the solvers
        # themselves need not be hashable.
        flag_phi_abs = {
            solver.id: solver.flag_phi_abs for solver in beam_calculators
        }
        # Counted on the solvers directly, so two solvers sharing an ``id``
        # cannot hide an inconsistency.
        n_unique_values = len(
            {solver.flag_phi_abs for solver in beam_calculators}
        )
        if n_unique_values > 1:
            logging.warning(
                "The different BeamCalculator objects have different values "
                "for flag_phi_abs. This may lead to inconsistencies when "
                f"cavities fail.\n{flag_phi_abs = }"
            )