"""Define a list-based class holding all the |F| to fix.
We also define :func:`fault_scenario_factory`, a factory function creating all
the required |FS| objects.
"""
import datetime
import logging
import time
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Self, overload
from lightwin.beam_calculation.beam_calculator import BeamCalculator
from lightwin.beam_calculation.simulation_output.simulation_output import (
SimulationOutput,
)
from lightwin.beam_calculation.tracewin.tracewin import TraceWin
from lightwin.core.accelerator.accelerator import Accelerator
from lightwin.core.elements.element import Element
from lightwin.core.elements.field_maps.field_map import FieldMap
from lightwin.core.list_of_elements.list_of_elements import (
ELEMENTS_ID_T,
NESTED_ELEMENTS_ID,
sumup_cavities,
)
from lightwin.evaluator.list_of_simulation_output_evaluators import (
FaultScenarioSimulationOutputEvaluators,
)
from lightwin.failures import strategy
from lightwin.failures.fault import Fault
from lightwin.optimisation.algorithms.algorithm import OptimisationAlgorithm
from lightwin.optimisation.algorithms.factory import (
OptimisationAlgorithmFactory,
)
from lightwin.optimisation.design_space.factory import (
DesignSpaceFactory,
get_design_space_factory,
)
from lightwin.optimisation.objective.factory import (
ObjectiveFactory,
ObjectiveMetaFactory,
)
from lightwin.util.helper import pd_output
from lightwin.util.pickling import MyPickler
from lightwin.util.typing import (
REFERENCE_PHASE_POLICY_T,
REFERENCE_PHASES,
REFERENCE_PHASES_T,
DesignSpaceKw,
)
[docs]
class FaultScenario(list[Fault]):
"""A class to hold all fault related data."""
[docs]
def __init__(
self,
ref_acc: Accelerator,
fix_acc: Accelerator,
beam_calculator: BeamCalculator,
wtf: dict[str, Any],
design_space_factory: DesignSpaceFactory,
fault_idx: ELEMENTS_ID_T | NESTED_ELEMENTS_ID,
comp_idx: NESTED_ELEMENTS_ID | None = None,
info_other_sol: list[dict] | None = None,
objective_factory_class: type[ObjectiveFactory] | None = None,
**kwargs,
) -> None:
"""Create the |FS| and the |F| objects.
Parameters
----------
ref_acc :
The reference linac (nominal or baseline).
fix_acc :
The broken linac to be fixed.
beam_calculator :
The solver that will be called during the optimisation process.
wtf :
What To Fit dictionary. Holds information on the fixing method.
design_space_factory :
An object to easily create the proper :class:`.DesignSpace`.
fault_idx :
List containing the position of the errors. If ``strategy`` is
manual, it is a list of lists (faults already gathered).
comp_idx :
List containing the position of the compensating cavities. If
``strategy`` is manual, it must be provided.
info_other_sol :
Contains information on another fit, for comparison purposes.
objective_factory_class :
If provided, will override the ``objective_preset``. Used to let
user define it's own :class:`.ObjectiveFactory` without altering
the source code.
"""
self.ref_acc = ref_acc
self.fix_acc = fix_acc
self.beam_calculator = beam_calculator
self.wtf = wtf
self.info_other_sol = info_other_sol
self.info = {}
self.optimisation_time: datetime.timedelta
#: If optimization should be skipped for these :class:`.Fault`\s.
self.skip_optimization = fix_acc.status == "fix"
self._design_space_factory = design_space_factory
self._list_of_elements_factory = (
beam_calculator.list_of_elements_factory
)
self._objective_factory_class = objective_factory_class
self._objective_meta_factory = ObjectiveMetaFactory(
self._reference_simulation_output
)
self._ensure_phi0_is_preserved()
cavities = strategy.failed_and_compensating(
fix_acc.elts, failed=fault_idx, compensating_manual=comp_idx, **wtf
)
faults = self._create_faults(*cavities)
super().__init__(faults)
self._optimisation_algorithm_factory = OptimisationAlgorithmFactory(
opti_method=wtf["optimisation_algorithm"],
beam_calculator=beam_calculator,
reference_simulation_output=self._reference_simulation_output,
accelerator_id=fix_acc.id,
**wtf,
)
self._objective_factories: list[ObjectiveFactory] = []
if self.skip_optimization:
return
self._mark_cavities_to_rephase()
for fault in self:
fault.pre_compensation_status()
self.beam_calculator.compute(
self.fix_acc,
ref_simulation_output=self._reference_simulation_output,
)
[docs]
def _create_faults(
self, *cavities: Sequence[Sequence[FieldMap]]
) -> list[Fault]:
"""Create the |F| objects.
Parameters
----------
*cavities :
First if the list of gathered failed cavities. Second is the list
of corresponding compensating cavities.
"""
faults = [
Fault(
reference_elts=self.ref_acc.elts,
broken_elts=self.fix_acc.elts,
failed_elements=faulty_cavities,
compensating_elements=compensating_cavities,
skip_optimization=self.skip_optimization,
)
for faulty_cavities, compensating_cavities in zip(
*cavities, strict=True
)
]
return faults
@property
def _reference_simulation_output(self) -> SimulationOutput:
"""Determine wich |SO| is the reference."""
solvers_already_used = list(self.ref_acc.simulation_outputs.keys())
assert len(solvers_already_used) > 0, (
"You must compute propagation of the beam in the reference linac "
"prior to create a FaultScenario"
)
solv1 = solvers_already_used[0]
reference_simulation_output = self.ref_acc.simulation_outputs[solv1]
return reference_simulation_output
[docs]
def fix_all(self) -> None:
"""Fix all the |F| objects in self."""
start_time = time.monotonic()
simulation_output = self._reference_simulation_output
for fault in self:
simulation_output = self._wrap_fix(fault, simulation_output)
delta_t = datetime.timedelta(seconds=time.monotonic() - start_time)
logging.info(f"Solving all the optimization problems took {delta_t}")
self.optimisation_time = delta_t
self.fix_acc.status = "fix"
self._evaluate_fit_quality(save=True)
if self.skip_optimization:
return
self.fix_acc.keep(
simulation_output=simulation_output,
exported_phase=self.beam_calculator.reference_phase_policy,
beam_calculator_id=self.beam_calculator.id,
skip_pickle=False,
)
[docs]
def _wrap_fix(
self, fault: Fault, simulation_output: SimulationOutput
) -> SimulationOutput:
"""Fix the fault and recompute propagation with new settings.
Orchestrates:
- build :class:`.DesignSpace`
- build :class:`.ObjectiveFactory` (objectives + residuals routine)
- create :class:`.OptimisationAlgorithm` (solver) using factories
above
- run :meth:`.Fault.fix`
- postprocess and logging
Parameters
----------
fault :
The fault to fix.
simulation_output :
The most recent simulation, that includes the compensation settings
of all |F| upstream of ``fault``.
Returns
-------
Most recent simulation, that includes the compensation settings of
upstream |F| as well as of this one.
"""
optimisation_algorithm = self._prepare_fix_objects(
fault, simulation_output
)
fault.fix(optimisation_algorithm)
simulation_output = fault.postprocess_fix(
self.fix_acc,
self.beam_calculator,
self._reference_simulation_output,
self._reference_phase_policy,
)
# TODO clean following
df_altered = sumup_cavities(
fault.subset_elts, filter=lambda cav: cav.is_altered
)
logging.info(f"Retuned cavities:\n{pd_output(df_altered)}")
fault.subset_elts.store_settings_in_dat(
fault.subset_elts.files_info["dat_file"],
exported_phase=self.beam_calculator.reference_phase_policy,
save=True,
)
return simulation_output
[docs]
def _prepare_fix_objects(
self, fault: Fault, simulation_output: SimulationOutput
) -> OptimisationAlgorithm:
"""Create objects to instantiate the :class:`.OptimisationAlgorithm`.
In particular:
- build :class:`.DesignSpace`
- build :class:`.ObjectiveFactory`
- extract |LOE| corresponding to zone to recompute
- this ``subset_elts`` is kept as a |F| attribute.
- create :class:`.OptimisationAlgorithm`.
- if :attr:`.skip_optimization` is ``True``, we rather
instantiate the special :class:`.PredefinedSolution`.
Parameters
----------
fault :
The fault to fix.
simulation_output :
The most recent simulation, that includes the compensation settings
of all |F| upstream of ``fault``.
Returns
-------
Object that will fix the failure.
"""
design_space = self._design_space_factory.create(
fault.compensating_elements, fault.reference_elements
)
objective_factory = self._objective_meta_factory.create(
objective_preset=self.wtf["objective_preset"],
limits_from_design_space_kw=self._design_space_factory.limits_from_design_space_kw,
packed_elements=fault.packed_elements,
objective_factory_class=self._objective_factory_class,
)
self._objective_factories.append(objective_factory)
subset_elts = self._list_of_elements_factory.subset_list_run(
objective_factory.elts_of_compensation_zone,
simulation_output,
self.fix_acc.elts.files_info,
)
fault.subset_elts = subset_elts
logging.info(
"Created a ListOfElements ecompassing a linac subset.\n"
f"Encompasses: {subset_elts[0]} to {subset_elts[-1]}\nw_kin_in = "
f"{subset_elts.w_kin_in:.2f} MeV\nphi_abs_in = "
f"{subset_elts.phi_abs_in:.2f} rad"
)
if not self.skip_optimization:
return self._optimisation_algorithm_factory.create(
compensating_elements=fault.compensating_elements,
objective_factory=objective_factory,
design_space=design_space,
subset_elts=subset_elts,
)
simulation_outputs = self.fix_acc.simulation_outputs
predefined_simulation_output = simulation_outputs.get(
self.beam_calculator.id
)
if predefined_simulation_output is None:
raise ValueError(
"Fixed accelerator do not have a SimulationOutput calculated "
f"with current solver {self.beam_calculator.id}."
)
predefined_cavity_settings = (
predefined_simulation_output.set_of_cavity_settings
)
return self._optimisation_algorithm_factory.create_from_preset(
compensating_elements=fault.compensating_elements,
objective_factory=objective_factory,
design_space=design_space,
subset_elts=subset_elts,
predefined_cavity_settings=predefined_cavity_settings,
predefined_simulation_output=predefined_simulation_output,
)
[docs]
def _evaluate_fit_quality(
self,
save: bool = True,
id_solver_ref: str | None = None,
id_solver_fix: str | None = None,
) -> None:
"""Compute some quantities on the whole linac to see if fit is good.
Parameters
----------
save :
To tell if you want to save the evaluation.
id_solver_ref :
Id of the solver from which you want reference results. The default
is None. In this case, the first solver is taken
(``beam_calc_param``).
id_solver_fix :
Id of the solver from which you want fixed results. The default is
None. In this case, the solver is the same as for reference.
"""
simulations = self._simulations_that_should_be_compared(
id_solver_ref, id_solver_fix
)
quantities_to_evaluate = (
"w_kin",
"phi_abs",
"envelope_pos_phiw",
"envelope_energy_phiw",
"eps_phiw",
"mismatch_factor_zdelta",
)
my_evaluator = FaultScenarioSimulationOutputEvaluators(
quantities_to_evaluate, self._objective_factories, simulations
)
my_evaluator.run(output=True)
# if save:
# fname = 'evaluations_differences_between_simulation_output.csv'
# out = os.path.join(self.fix_acc.get('beam_calc_path'), fname)
# df_eval.to_csv(out)
[docs]
def _set_evaluation_elements(
self, additional_elt: list[Element] | None = None
) -> list[Element]:
"""Set a the proper list of where to check the fit quality."""
evaluation_elements = [fault.subset_elts[-1] for fault in self]
if additional_elt is not None:
evaluation_elements += additional_elt
evaluation_elements.append(self.fix_acc.elts[-1])
return evaluation_elements
[docs]
def _simulations_that_should_be_compared(
self, id_solver_ref: str | None, id_solver_fix: str | None
) -> tuple[SimulationOutput, SimulationOutput]:
"""Get proper |SO| for comparison."""
if id_solver_ref is None:
id_solver_ref = list(self.ref_acc.simulation_outputs.keys())[0]
if id_solver_fix is None:
id_solver_fix = id_solver_ref
if id_solver_ref != id_solver_fix:
logging.warning(
"You are trying to compare two SimulationOutputs created by "
"two different solvers. This may lead to errors, as "
"interpolations in this case are not implemented yet."
)
ref_simu = self.ref_acc.simulation_outputs[id_solver_ref]
fix_simu = self.fix_acc.simulation_outputs[id_solver_fix]
return ref_simu, fix_simu
[docs]
def pickle(
self, pickler: MyPickler, path: Path | str | None = None
) -> Path:
"""Pickle (save) the object.
This is useful for debug and temporary saves; do not use it for long
time saving.
"""
if path is None:
path = self.fix_acc.accelerator_path / "fault_scenario.pkl"
assert isinstance(path, Path)
pickler.pickle(self, path)
if isinstance(path, str):
path = Path(path)
return path
[docs]
@classmethod
def from_pickle(cls, pickler: MyPickler, path: Path | str) -> Self:
"""Instantiate object from previously pickled file."""
fault_scenario = pickler.unpickle(path)
return fault_scenario # type: ignore
[docs]
def _mark_cavities_to_rephase(self) -> None:
"""Change the status of cavities after first failure.
Only cavities with a reference phase different from ``"phi_0_abs"`` are
altered.
.. todo::
Could probably be simpler.
"""
if self.skip_optimization:
return
if self._reference_phase_policy == "phi_0_abs":
return
cavities = self.fix_acc.l_cav
first_failed_index = cavities.index(self[0].failed_elements[0])
cavities_after_first_failure = cavities[first_failed_index:]
cavities_to_rephase = [
c
for c in cavities_after_first_failure
if c.cavity_settings.reference != "phi_0_abs"
]
logging.info(
f"Marking {len(cavities_to_rephase)} cavities as 'to be rephased',"
" because they are after a failed cavity and their reference phase "
"is phi_s or phi_0_rel."
)
for cav in cavities_to_rephase:
cav.update_status("rephased (in progress)")
# =========================================================================
# Reference phase related
# =========================================================================
[docs]
def _ensure_phi0_is_preserved(self) -> None:
"""Transfer the reference phases from reference linac to broken.
If the absolute initial phases are not kept between reference and
broken linac, it comes down to rephasing the linac. This is what we
want to avoid when :attr:`.BeamCalculator.reference_phase_policy` is
set to ``"phi_0_abs"``.
If the linac was already fixed, we simply skip this step.
"""
if self.skip_optimization:
return
ref_cavs = (x for x in self.ref_acc.l_cav)
fix_settings = (x.cavity_settings for x in self.fix_acc.l_cav)
for ref_cav, fix_set in zip(ref_cavs, fix_settings):
reference_phase = self._resolve_reference_phase(ref_cav)
fix_set.set_reference(
reference=reference_phase,
phi_ref=getattr(ref_cav.cavity_settings, reference_phase),
ensure_can_be_calculated=False,
)
@property
def _reference_phase_policy(self) -> REFERENCE_PHASE_POLICY_T:
"""Give reference phase policy of |BC|."""
return self.beam_calculator.reference_phase_policy
[docs]
def _resolve_reference_phase(
self, reference_cavity: FieldMap
) -> REFERENCE_PHASES_T:
"""Get the reference phase matching the reference phase policy.
According to the value of
:attr:`.BeamCalculator.reference_phase_policy`:
- ``"phi_0_abs"``, ``"phi_0_rel"``, ``"phi_s"``: take this reference.
- ```"as_in_original_dat"``: take reference from ``reference_cavity``.
"""
if self._reference_phase_policy in REFERENCE_PHASES:
return self._reference_phase_policy
return reference_cavity.cavity_settings.reference
[docs]
class FaultScenarioFactory:
"""This objects consistently create |FS|."""
[docs]
def __init__(
self,
accelerators: dict[int, list[Accelerator]],
beam_calc: BeamCalculator,
design_space: DesignSpaceKw,
objective_factory_class: type[ObjectiveFactory] | None = None,
) -> None:
"""Init solver parameters for each non-unpickled |A|.
Parameters
----------
accelerators :
Dictionary where keys are |FS| indexes, and values are lists of
corresponding |A|. First index corresponds to reference accelerator
(no failure).
beam_calc :
The solver that will be called during the optimisation process.
design_space_kw :
The design space table from the TOML configuration file.
objective_factory_class :
If provided, will override the ``objective_preset``. Used to let
user define it's own :class:`.ObjectiveFactory` without altering
the source code.
Returns
-------
Holds all the initialized |FS| objects, holding their already
initialied |F| objects.
"""
self._accelerators = accelerators
self._beam_calculator = beam_calc
self._design_space_factory = get_design_space_factory(**design_space)
self._objective_factory_class = objective_factory_class
for index, acc in accelerators.items():
for a in acc:
if a.is_unpickled:
continue
if isinstance(beam_calc, TraceWin) and index > 0:
_force_element_to_index_method_creation(a, beam_calc)
beam_calc.init_solver_parameters(a)
@overload
def create(
self, failed: NESTED_ELEMENTS_ID, compensating_manual: None, **wtf
) -> list[FaultScenario]: ...
@overload
def create(
self,
failed: list[NESTED_ELEMENTS_ID],
compensating_manual: list[NESTED_ELEMENTS_ID],
**wtf,
) -> list[FaultScenario]: ...
[docs]
def create(
self,
failed: NESTED_ELEMENTS_ID | list[NESTED_ELEMENTS_ID],
compensating_manual: list[NESTED_ELEMENTS_ID] | None = None,
**wtf,
) -> list[FaultScenario]:
"""Instantiate fault scenarios.
Parameters
----------
failed :
Index or name of the failed cavities.
compensating_manual :
List of compensating cavities associated with ``failed``.
wtf :
The WhatToFit table of the ``TOML`` configuration file.
"""
fault_scenarios: list[FaultScenario] = []
if len(failed) != len(self._accelerators) - 1:
raise ValueError(
f"We are creating {len(failed)} FaultScenarios, but we have "
f"{len(self._accelerators) - 1} non-reference groups of "
"Accelerators."
)
ref_acc = self._accelerators[0][0]
for index, accelerators_by_scenario in self._accelerators.items():
if index == 0:
continue
# Position of compensating cavities, when the strategy is manual
comp_idx = (
None
if compensating_manual is None
else compensating_manual[index - 1]
)
scenario = FaultScenario(
ref_acc=ref_acc,
fix_acc=accelerators_by_scenario[0],
beam_calculator=self._beam_calculator,
wtf=wtf,
design_space_factory=self._design_space_factory,
fault_idx=failed[index - 1],
comp_idx=comp_idx,
objective_factory_class=self._objective_factory_class,
)
fault_scenarios.append(scenario)
return fault_scenarios
[docs]
def fault_scenario_factory(
accelerators: list[Accelerator],
beam_calc: BeamCalculator,
wtf: dict[str, Any],
design_space: dict[str, Any],
objective_factory_class: type[ObjectiveFactory] | None = None,
**kwargs,
) -> list[FaultScenario]:
"""Create the |FS| objects (factory template).
.. deprecated:: 0.14.1
Prefer the more flexible:
.. code-block:: python
factory = FaultScenarioFactory(
accelerators=accelerators,
beam_calc=beam_calc,
design_space=design_space,
objective_factory_class=objective_factory_class,
)
fault_scenarios = factory.create(**wtf)
Parameters
----------
accelerators :
Holds all the linacs. The first one must be the reference linac,
while all the others will be to be fixed.
beam_calc :
The solver that will be called during the optimisation process.
wtf :
The WhatToFit table of the TOML configuration file.
design_space_kw :
The design space table from the TOML configuration file.
objective_factory_class :
If provided, will override the ``objective_preset``. Used to let user
define it's own :class:`.ObjectiveFactory` without altering the source
code.
Returns
-------
Holds all the initialized |FS| objects, holding their already
initialied |F| objects.
"""
adapted = {
index: [accelerator] for index, accelerator in enumerate(accelerators)
}
factory = FaultScenarioFactory(
accelerators=adapted,
beam_calc=beam_calc,
design_space=design_space,
objective_factory_class=objective_factory_class,
)
return factory.create(**wtf)
[docs]
def _force_element_to_index_method_creation(
accelerator: Accelerator,
beam_calculator: BeamCalculator,
) -> None:
"""Run a first simulation to link |E| with their index.
.. note::
To initalize a |F|, you need a sub|LOE|. To create the latter, you need
a ``_element_to_index`` method. It can only be created if you know the
number of steps in every |E|. So, for :class:`.TraceWin`, we run a
first simulation.
"""
beam_calculator.compute(accelerator)