"""Define helper functions to set up LightWin workflow."""
import logging
from collections.abc import Collection
from typing import Any
from lightwin.beam_calculation.beam_calculator import BeamCalculator
from lightwin.beam_calculation.factory import BeamCalculatorsFactory
from lightwin.beam_calculation.simulation_output.simulation_output import (
SimulationOutput,
)
from lightwin.core.accelerator.accelerator import Accelerator
from lightwin.core.accelerator.factory import AcceleratorFactory
from lightwin.failures.fault_scenario import (
FaultScenario,
FaultScenarioFactory,
)
from lightwin.optimisation.objective.factory import ObjectiveFactory
from lightwin.util.typing import BeamKwargs, ConfigKw
from lightwin.visualization import plot
[docs]
def set_up_solvers(
    files: dict[str, Any],
    beam: BeamKwargs,
    beam_calculator: dict[str, Any],
    beam_calculator_post: dict[str, Any] | None = None,
    reset_factory: bool = False,
    **config,
) -> tuple[BeamCalculator, ...]:
    """Instantiate the beam calculators described in the configuration.

    Parameters
    ----------
    files :
        Configuration entries for the input/output paths.
    beam :
        Configuration dictionary holding the initial beam parameters.
    beam_calculator :
        Configuration entries for the first |BC|, used for optimisation.
    beam_calculator_post :
        Configuration entries for the optional second |BC|, used for a more
        thorough calculation of the beam propagation once the compensation
        settings are found. It is forwarded to the factory as-is, even when
        it is ``None``.
    reset_factory :
        When ``True``, call :meth:`.BeamCalculatorsFactory.reset` and
        :meth:`.BeamCalculator.reset_ids` before creating anything. Use this
        if you want to change the ``files`` or ``beam`` without restarting
        the Python kernel -- for example during ``pytest``.
    config :
        Other ``TOML`` configuration dictionaries. They are accepted so that
        the full configuration can be unpacked into this function, but they
        are not used here.

    Returns
    -------
    The objects that will compute the beam propagation.

    """
    if reset_factory:
        BeamCalculatorsFactory.reset()
        BeamCalculator.reset_ids()

    # Order matters: the optimisation solver comes first, the optional
    # post-treatment solver second.
    solver_configs = [beam_calculator, beam_calculator_post]
    return BeamCalculatorsFactory(files=files, beam=beam).run_all(
        solver_configs
    )
[docs]
def set_up_accelerators(
    config: ConfigKw, beam_calculators: tuple[BeamCalculator, ...]
) -> dict[int, list[Accelerator]]:
    """Build every accelerator required by the study.

    .. note::
        If an automatic study is asked, the ``wtf`` dictionary is updated to
        explicitly mention the list of failed cavities. Concretely, when the
        factory returns a ``wtf`` that is not ``None``, it is stored back
        into ``config["wtf"]``, so ``config`` is modified in place.

    Parameters
    ----------
    config :
        The full ``TOML`` configuration dictionary.
    beam_calculators :
        The objects that will compute the beam propagation.

    Returns
    -------
    Dictionary where keys are |FS| indexes, and values are lists of
    corresponding |A|. First index corresponds to reference accelerator (no
    failure).

    """
    accelerator_factory = AcceleratorFactory(beam_calculators, **config)
    created, new_wtf = accelerator_factory.create_all(config.get("wtf"))

    if new_wtf is None:
        return created

    # Propagate the updated failure description to the caller's config.
    config["wtf"] = new_wtf
    return created
[docs]
def set_up_faults(
    config: ConfigKw,
    beam_calculator: BeamCalculator,
    accelerators: dict[int, list[Accelerator]],
    objective_factory_class: type[ObjectiveFactory] | None = None,
    **kwargs,
) -> list[FaultScenario]:
    """Create all the |F|, gather them in |FS|.

    Both mandatory configuration entries (``"design_space"`` and ``"wtf"``)
    are validated before the :class:`.FaultScenarioFactory` is instantiated,
    so an incomplete configuration fails immediately instead of after the
    factory has been built.

    Parameters
    ----------
    config :
        The full ``TOML`` configuration dict.
    beam_calculator :
        The object that will be used for the optimization. Usually, a fast
        solver such as :class:`.CyEnvelope1D`.
    accelerators :
        Dictionary where keys are |FS| indexes, and values are lists of
        corresponding |A|. First index corresponds to reference accelerator (no
        failure).
    objective_factory_class :
        If provided, will override the ``objective_preset``. Used to let user
        define its own :class:`.ObjectiveFactory` without altering the source
        code.
    **kwargs :
        Accepted so that callers can forward extra keyword arguments; they
        are not used by this function.

    Returns
    -------
    The instantiated fault scenarios.

    Raises
    ------
    ValueError
        If the ``"design_space"`` or the ``"wtf"`` entry is missing from
        ``config`` or is ``None``.

    """
    design_space_kw = config.get("design_space")
    if design_space_kw is None:
        raise ValueError("design_space configuration is necessary")

    # Checked before building the factory: without ``wtf`` the call cannot
    # succeed, so there is no point in instantiating anything.
    wtf = config.get("wtf")
    if wtf is None:
        raise ValueError("wtf configuration is necessary")

    factory = FaultScenarioFactory(
        accelerators,
        beam_calculator,
        design_space_kw,
        objective_factory_class=objective_factory_class,
    )
    return factory.create(**wtf)
[docs]
def set_up(config: ConfigKw, **kwargs) -> tuple[
    tuple[BeamCalculator, ...],
    dict[int, list[Accelerator]],
    list[FaultScenario] | None,
    list[SimulationOutput],
]:
    """Build every object needed by a typical LightWin simulation.

    The beam calculators are created first, then the accelerators, then one
    reference simulation is computed per beam calculator on
    ``accelerators[0][0]``. Fault scenarios are only created when ``config``
    holds a ``"wtf"`` key.

    Parameters
    ----------
    config :
        The full ``TOML`` configuration dictionary.
    **kwargs :
        Forwarded to :func:`set_up_faults` when fault scenarios are created.

    Returns
    -------
    beam_calculators :
        The objects to compute the beam. Typically, they are two: one for the
        optimization, and a second slower one to run a more precise simulation.
    accelerators :
        Dictionary where keys are |FS| indexes, and values are lists of
        corresponding |A|. First index corresponds to reference accelerator (no
        failure).
    fault_scenarios :
        The created failures. Will be None if no ``"wtf"`` entry was given in
        ``config``.
    ref_simulations_outputs :
        A reference |SO| corresponding to the nominal linac per |BC|.

    """
    solvers = set_up_solvers(**config)
    linacs = set_up_accelerators(config, solvers)

    # One nominal-linac simulation per solver, in solver order.
    references = [solver.compute(linacs[0][0]) for solver in solvers]

    scenarios = (
        set_up_faults(config, solvers[0], linacs, **kwargs)
        if "wtf" in config
        else None
    )
    return solvers, linacs, scenarios, references
[docs]
def fix(fault_scenarios: Collection[FaultScenario] | None) -> None:
    """Call ``fix_all`` on every fault scenario.

    Parameters
    ----------
    fault_scenarios :
        The created failures. Will be None if no ``"wtf"`` entry was given in
        ``config``; in that case an info message is logged and nothing else
        happens.

    """
    if fault_scenarios is not None:
        for scenario in fault_scenarios:
            scenario.fix_all()
        return

    logging.info("No fault was set!")
[docs]
def recompute(
beam_calculators: Collection[BeamCalculator],
references: Collection[SimulationOutput],
accelerators: dict[int, list[Accelerator]],
) -> list[list[SimulationOutput]]:
"""Recompute accelerator after a fix with more precision.
.. todo::
Maybe in some cases we want to also recompute the unpickled
Accelerators.
Parameters
----------
beam_calculators :
One or several beam calculators.
references :
A reference |SO| per |BC|, ideally generated by the same |BC|.
accelerators :
Dictionary where keys are |FS| indexes, and values are lists of
corresponding |A|. First index corresponds to reference accelerator (no
failure).
Returns
-------
A nested list of simulation results.
"""
to_recompute = [sublist[0] for sublist in accelerators.values()]
reference_accelerator = to_recompute.pop(0)
assert reference_accelerator.index == 0
simulation_outputs = [
[
beam_calculator.compute(
accelerator, ref_simulation_output=reference_simulation
)
for accelerator in to_recompute
if not accelerator.is_unpickled
]
for beam_calculator, reference_simulation in zip(
beam_calculators, references, strict=True
)
]
return simulation_outputs
[docs]
def run_simulation(
    config: ConfigKw, **kwargs
) -> list[FaultScenario] | dict[int, list[Accelerator]]:
    """Compute propagation of beam; if failures are defined, fix them.

    When fault scenarios exist, they are fixed, the accelerators are
    recomputed with every beam calculator but the first one, and the plots
    are produced. Otherwise only the plots are produced.

    Parameters
    ----------
    config :
        The full TOML configuration dict.
    **kwargs :
        Forwarded to :func:`set_up`.

    Returns
    -------
    If no failure is defined, return the created accelerators. If failures
    were defined, return the full fault scenarios. Note that you can access
    the accelerator objects with ``FaultScenario.ref_acc`` and
    ``FaultScenario.fix_acc``.

    """
    solvers, linacs, scenarios, references = set_up(config, **kwargs)

    if scenarios is not None:
        fix(scenarios)
        # The first solver already served for the optimisation; only the
        # remaining ones are used for the recomputation.
        recompute(solvers[1:], references[1:], linacs)
        plot.factory(linacs, fault_scenarios=scenarios, **config)
        return scenarios

    plot.factory(linacs, **config)
    return linacs
[docs]
def run_simulation_new(
    config: ConfigKw, **kwargs
) -> tuple[dict[int, list[Accelerator]], list[FaultScenario] | None]:
    """Compute propagation of beam; if failures are defined, fix them.

    Unlike :func:`run_simulation`, the accelerators are always returned,
    together with the fault scenarios (or ``None``).

    Parameters
    ----------
    config :
        The full TOML configuration dict.
    **kwargs :
        Forwarded to :func:`set_up`.

    Returns
    -------
    accelerators : dict[int, list[Accelerator]]
        Keys are |FS| indexes (0 is for reference). Values are corresponding
        |A| as a list; there is typically one |A| in each list, and additional
        ones are unpickled.
    fault_scenarios :
        The fault scenarios if failure(s) were defined, ``None`` otherwise.

    """
    solvers, linacs, scenarios, references = set_up(config, **kwargs)

    if scenarios is not None:
        fix(scenarios)
        # The first solver already served for the optimisation; only the
        # remaining ones are used for the recomputation.
        recompute(solvers[1:], references[1:], linacs)
        plot.factory(linacs, fault_scenarios=scenarios, **config)
        return linacs, scenarios

    plot.factory(linacs, **config)
    return linacs, None