"""Define a list-based class holding all the :class:`.Fault` to fix.
We also define :func:`fault_scenario_factory`, a factory function creating all
the required :class:`FaultScenario` objects.
"""
import datetime
import logging
import time
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Self
from lightwin.beam_calculation.beam_calculator import BeamCalculator
from lightwin.beam_calculation.simulation_output.simulation_output import (
SimulationOutput,
)
from lightwin.beam_calculation.tracewin.tracewin import TraceWin
from lightwin.core.accelerator.accelerator import Accelerator
from lightwin.core.elements.element import Element
from lightwin.core.elements.field_maps.field_map import FieldMap
from lightwin.core.list_of_elements.list_of_elements import sumup_cavities
from lightwin.evaluator.list_of_simulation_output_evaluators import (
FaultScenarioSimulationOutputEvaluators,
)
from lightwin.failures import strategy
from lightwin.failures.fault import Fault
from lightwin.optimisation.algorithms.algorithm import OptimisationAlgorithm
from lightwin.optimisation.algorithms.factory import (
OptimisationAlgorithmFactory,
)
from lightwin.optimisation.design_space.factory import (
DesignSpaceFactory,
get_design_space_factory,
)
from lightwin.optimisation.objective.factory import (
ObjectiveFactory,
ObjectiveMetaFactory,
)
from lightwin.util.helper import pd_output
from lightwin.util.pickling import MyPickler
from lightwin.util.typing import (
REFERENCE_PHASE_POLICY_T,
REFERENCE_PHASES,
REFERENCE_PHASES_T,
)
[docs]
class FaultScenario(list[Fault]):
"""A class to hold all fault related data."""
[docs]
def __init__(
self,
ref_acc: Accelerator,
fix_acc: Accelerator,
beam_calculator: BeamCalculator,
wtf: dict[str, Any],
design_space_factory: DesignSpaceFactory,
fault_idx: list[int] | list[list[int]],
comp_idx: list[list[int]] | None = None,
info_other_sol: list[dict] | None = None,
objective_factory_class: type[ObjectiveFactory] | None = None,
**kwargs,
) -> None:
"""Create the :class:`FaultScenario` and the :class:`.Fault` objects.
Parameters
----------
ref_acc :
The reference linac (nominal or baseline).
fix_acc :
The broken linac to be fixed.
beam_calculator :
The solver that will be called during the optimisation process.
wtf :
What To Fit dictionary. Holds information on the fixing method.
design_space_factory :
An object to easily create the proper :class:`.DesignSpace`.
fault_idx :
List containing the position of the errors. If ``strategy`` is
manual, it is a list of lists (faults already gathered).
comp_idx :
List containing the position of the compensating cavities. If
``strategy`` is manual, it must be provided.
info_other_sol :
Contains information on another fit, for comparison purposes.
objective_factory_class :
If provided, will override the ``objective_preset``. Used to let
user define it's own :class:`.ObjectiveFactory` without altering
the source code.
"""
self.ref_acc = ref_acc
self.fix_acc = fix_acc
self.beam_calculator = beam_calculator
self._transfer_phi0_from_ref_to_broken()
self.wtf = wtf
self.info_other_sol = info_other_sol
self.info = {}
self.optimisation_time: datetime.timedelta
self._design_space_factory = design_space_factory
self._list_of_elements_factory = (
beam_calculator.list_of_elements_factory
)
self._objective_factory_class = objective_factory_class
self._objective_meta_factory = ObjectiveMetaFactory(
self._reference_simulation_output
)
cavities = strategy.failed_and_compensating(
fix_acc.elts, failed=fault_idx, compensating_manual=comp_idx, **wtf
)
faults = self._create_faults(*cavities)
super().__init__(faults)
self._mark_cavities_to_rephase()
for fault in self:
fault.pre_compensation_status()
self._optimisation_algorithm_factory = OptimisationAlgorithmFactory(
opti_method=wtf["optimisation_algorithm"],
beam_calculator=beam_calculator,
reference_simulation_output=self._reference_simulation_output,
**wtf,
)
self._objective_factories: list[ObjectiveFactory] = []
[docs]
def _create_faults(
self, *cavities: Sequence[Sequence[FieldMap]]
) -> list[Fault]:
"""Create the :class:`.Fault` objects.
Parameters
----------
*cavities :
First if the list of gathered failed cavities. Second is the list
of corresponding compensating cavities.
"""
faults = [
Fault(
reference_elts=self.ref_acc.elts,
broken_elts=self.fix_acc.elts,
failed_elements=faulty_cavities,
compensating_elements=compensating_cavities,
)
for faulty_cavities, compensating_cavities in zip(
*cavities, strict=True
)
]
return faults
@property
def _reference_simulation_output(self) -> SimulationOutput:
"""Determine wich :class:`.SimulationOutput` is the reference."""
solvers_already_used = list(self.ref_acc.simulation_outputs.keys())
assert len(solvers_already_used) > 0, (
"You must compute propagation of the beam in the reference linac "
"prior to create a FaultScenario"
)
solv1 = solvers_already_used[0]
reference_simulation_output = self.ref_acc.simulation_outputs[solv1]
return reference_simulation_output
[docs]
def fix_all(self) -> None:
"""Fix all the :class:`.Fault` objects in self."""
start_time = time.monotonic()
simulation_output = self._reference_simulation_output
for fault in self:
simulation_output = self._wrap_fix(fault, simulation_output)
delta_t = datetime.timedelta(seconds=time.monotonic() - start_time)
logging.info(f"Solving all the optimization problems took {delta_t}")
self.optimisation_time = delta_t
successes = [fault.success for fault in self]
self.fix_acc.name = (
f"Fixed ({successes.count(True)} of {len(successes)})"
)
self._evaluate_fit_quality(save=True)
self.fix_acc.elts.store_settings_in_dat(
self.fix_acc.elts.files_info["dat_file"],
exported_phase=self.beam_calculator.reference_phase_policy,
save=True,
)
[docs]
def _wrap_fix(
self, fault: Fault, simulation_output: SimulationOutput
) -> SimulationOutput:
"""Fix the fault and recompute propagation with new settings.
Orchestrates:
- build :class:`.DesignSpace`
- build :class:`.ObjectiveFactory` (objectives + residuals routine)
- create :class:`.OptimisationAlgorithm` (solver) using factories
above
- run :meth:`.Fault.fix`
- postprocess and logging
Parameters
----------
fault :
The fault to fix.
simulation_output :
The most recent simulation, that includes the compensation settings
of all :class:`.Fault` upstream of ``fault``.
Returns
-------
Most recent simulation, that includes the compensation settings of
upstream :class:`.Fault` as well as of this one.
"""
optimisation_algorithm = self._prepare_fix_objects(
fault, simulation_output
)
fault.fix(optimisation_algorithm)
simulation_output = fault.postprocess_fix(
self.fix_acc,
self.beam_calculator,
self._reference_simulation_output,
self._reference_phase_policy,
)
# TODO clean following
df_altered = sumup_cavities(
fault.subset_elts, filter=lambda cav: cav.is_altered
)
logging.info(f"Retuned cavities:\n{pd_output(df_altered)}")
fault.subset_elts.store_settings_in_dat(
fault.subset_elts.files_info["dat_file"],
exported_phase=self.beam_calculator.reference_phase_policy,
save=True,
)
return simulation_output
[docs]
def _prepare_fix_objects(
self, fault: Fault, simulation_output: SimulationOutput
) -> OptimisationAlgorithm:
"""Create objects to instantiate the :class:`.OptimisationAlgorithm`."""
design_space = self._design_space_factory.create(
fault.compensating_elements, fault.reference_elements
)
objective_factory = self._objective_meta_factory.create(
self.wtf["objective_preset"],
self._design_space_factory.design_space_kw,
fault.packed_elements,
self._objective_factory_class,
)
self._objective_factories.append(objective_factory)
subset_elts = self._list_of_elements_factory.subset_list_run(
objective_factory.elts_of_compensation_zone,
simulation_output,
self.fix_acc.elts.files_info,
)
fault.subset_elts = subset_elts
logging.info(
"Created a ListOfElements ecompassing a linac subset.\n"
f"Encompasses: {subset_elts[0]} to {subset_elts[1]}\nw_kin_in = "
f"{subset_elts.w_kin_in:.2f} MeV\nphi_abs_in = "
f"{subset_elts.phi_abs_in:.2f} rad"
)
optimisation_algorithm = self._optimisation_algorithm_factory.create(
fault.compensating_elements,
objective_factory,
design_space,
subset_elts,
)
return optimisation_algorithm
[docs]
def _evaluate_fit_quality(
self,
save: bool = True,
id_solver_ref: str | None = None,
id_solver_fix: str | None = None,
) -> None:
"""Compute some quantities on the whole linac to see if fit is good.
Parameters
----------
save :
To tell if you want to save the evaluation.
id_solver_ref :
Id of the solver from which you want reference results. The default
is None. In this case, the first solver is taken
(``beam_calc_param``).
id_solver_fix :
Id of the solver from which you want fixed results. The default is
None. In this case, the solver is the same as for reference.
"""
simulations = self._simulations_that_should_be_compared(
id_solver_ref, id_solver_fix
)
quantities_to_evaluate = (
"w_kin",
"phi_abs",
"envelope_pos_phiw",
"envelope_energy_phiw",
"eps_phiw",
"mismatch_factor_zdelta",
)
my_evaluator = FaultScenarioSimulationOutputEvaluators(
quantities_to_evaluate, self._objective_factories, simulations
)
my_evaluator.run(output=True)
# if save:
# fname = 'evaluations_differences_between_simulation_output.csv'
# out = os.path.join(self.fix_acc.get('beam_calc_path'), fname)
# df_eval.to_csv(out)
[docs]
def _set_evaluation_elements(
self, additional_elt: list[Element] | None = None
) -> list[Element]:
"""Set a the proper list of where to check the fit quality."""
evaluation_elements = [fault.subset_elts[-1] for fault in self]
if additional_elt is not None:
evaluation_elements += additional_elt
evaluation_elements.append(self.fix_acc.elts[-1])
return evaluation_elements
[docs]
def _simulations_that_should_be_compared(
self, id_solver_ref: str | None, id_solver_fix: str | None
) -> tuple[SimulationOutput, SimulationOutput]:
"""Get proper :class:`.SimulationOutput` for comparison."""
if id_solver_ref is None:
id_solver_ref = list(self.ref_acc.simulation_outputs.keys())[0]
if id_solver_fix is None:
id_solver_fix = id_solver_ref
if id_solver_ref != id_solver_fix:
logging.warning(
"You are trying to compare two SimulationOutputs created by "
"two different solvers. This may lead to errors, as "
"interpolations in this case are not implemented yet."
)
ref_simu = self.ref_acc.simulation_outputs[id_solver_ref]
fix_simu = self.fix_acc.simulation_outputs[id_solver_fix]
return ref_simu, fix_simu
[docs]
def pickle(
self, pickler: MyPickler, path: Path | str | None = None
) -> Path:
"""Pickle (save) the object.
This is useful for debug and temporary saves; do not use it for long
time saving.
"""
if path is None:
path = self.fix_acc.accelerator_path / "fault_scenario.pkl"
assert isinstance(path, Path)
pickler.pickle(self, path)
if isinstance(path, str):
path = Path(path)
return path
[docs]
@classmethod
def from_pickle(cls, pickler: MyPickler, path: Path | str) -> Self:
"""Instantiate object from previously pickled file."""
fault_scenario = pickler.unpickle(path)
return fault_scenario # type: ignore
[docs]
def _mark_cavities_to_rephase(self) -> None:
"""Change the status of cavities after first failure.
Only cavities with a reference phase different from ``"phi_0_abs"`` are
altered.
.. todo::
Could probably be simpler.
"""
if self._reference_phase_policy == "phi_0_abs":
return
cavities = self.fix_acc.l_cav
first_failed_index = cavities.index(self[0].failed_elements[0])
cavities_after_first_failure = cavities[first_failed_index:]
cavities_to_rephase = [
c
for c in cavities_after_first_failure
if c.cavity_settings.reference != "phi_0_abs"
]
logging.info(
f"Marking {len(cavities_to_rephase)} cavities as 'to be rephased',"
" because they are after a failed cavity and their reference phase "
"is phi_s or phi_0_rel."
)
for cav in cavities_to_rephase:
cav.update_status("rephased (in progress)")
# =========================================================================
# Reference phase related
# =========================================================================
[docs]
def _transfer_phi0_from_ref_to_broken(self) -> None:
"""Transfer the reference phases from reference linac to broken.
If the absolute initial phases are not kept between reference and
broken linac, it comes down to rephasing the linac. This is what we
want to avoid when :attr:`.BeamCalculator.reference_phase_policy` is
set to ``"phi_0_abs"``.
"""
ref_cavs = (x for x in self.ref_acc.l_cav)
fix_settings = (x.cavity_settings for x in self.fix_acc.l_cav)
for ref_cav, fix_set in zip(ref_cavs, fix_settings):
reference_phase = self._resolve_reference_phase(ref_cav)
fix_set.set_reference(
reference=reference_phase,
phi_ref=getattr(ref_cav.cavity_settings, reference_phase),
ensure_can_be_calculated=False,
)
@property
def _reference_phase_policy(self) -> REFERENCE_PHASE_POLICY_T:
"""Give reference phase policy of :class:`.BeamCalculator`."""
return self.beam_calculator.reference_phase_policy
[docs]
def _resolve_reference_phase(
self, reference_cavity: FieldMap
) -> REFERENCE_PHASES_T:
"""Get the reference phase matching the reference phase policy.
According to the value of
:attr:`.BeamCalculator.reference_phase_policy`:
- ``"phi_0_abs"``, ``"phi_0_rel"``, ``"phi_s"``: take this reference.
- ```"as_in_original_dat"``: take reference from ``reference_cavity``.
"""
if self._reference_phase_policy in REFERENCE_PHASES:
return self._reference_phase_policy
return reference_cavity.cavity_settings.reference
[docs]
def fault_scenario_factory(
accelerators: list[Accelerator],
beam_calc: BeamCalculator,
wtf: dict[str, Any],
design_space: dict[str, Any],
objective_factory_class: type[ObjectiveFactory] | None = None,
**kwargs,
) -> list[FaultScenario]:
"""Create the :class:`FaultScenario` objects (factory template).
Parameters
----------
accelerators :
Holds all the linacs. The first one must be the reference linac,
while all the others will be to be fixed.
beam_calc :
The solver that will be called during the optimisation process.
wtf :
The WhatToFit table of the TOML configuration file.
design_space_kw :
The design space table from the TOML configuration file.
objective_factory_class :
If provided, will override the ``objective_preset``. Used to let user
define it's own :class:`.ObjectiveFactory` without altering the source
code.
Returns
-------
Holds all the initialized :class:`FaultScenario` objects, holding their
already initialied :class:`.Fault` objects.
"""
# TODO may be better to move this to beam_calculator.init_solver_parameters
need_to_force_element_to_index_creation = (TraceWin,)
if isinstance(beam_calc, *need_to_force_element_to_index_creation):
_force_element_to_index_method_creation(accelerators[1], beam_calc)
scenarios_fault_idx = wtf.pop("failed")
scenarios_comp_idx = [None for _ in accelerators[1:]]
if "compensating_manual" in wtf:
scenarios_comp_idx = wtf.pop("compensating_manual")
_ = [
beam_calc.init_solver_parameters(accelerator)
for accelerator in accelerators
]
design_space_factory: DesignSpaceFactory
design_space_factory = get_design_space_factory(**design_space)
fault_scenarios = [
FaultScenario(
ref_acc=accelerators[0],
fix_acc=accelerator,
beam_calculator=beam_calc,
wtf=wtf,
design_space_factory=design_space_factory,
fault_idx=fault_idx,
comp_idx=comp_idx,
objective_factory_class=objective_factory_class,
)
for accelerator, fault_idx, comp_idx in zip(
accelerators[1:], scenarios_fault_idx, scenarios_comp_idx
)
]
return fault_scenarios
[docs]
def _force_element_to_index_method_creation(
accelerator: Accelerator,
beam_calculator: BeamCalculator,
) -> None:
"""Run a first simulation to link :class:`.Element` with their index.
.. note::
To initalize a :class:`.Fault`, you need a sub:class:`.ListOfElements`.
To create the latter, you need a ``_element_to_index`` method. It can
only be created if you know the number of steps in every
:class:`.Element`. So, for :class:`.TraceWin`, we run a first
simulation.
"""
beam_calculator.compute(accelerator)