"""Define a factory function to create :class:`.OptimisationAlgorithm`."""
import logging
from abc import ABCMeta
from collections.abc import Collection, Mapping
from typing import Any, Literal
from lightwin.beam_calculation.beam_calculator import BeamCalculator
from lightwin.beam_calculation.simulation_output.simulation_output import (
SimulationOutput,
)
from lightwin.core.elements.element import Element
from lightwin.core.elements.field_maps.cavity_settings import CavitySettings
from lightwin.core.elements.field_maps.field_map import FieldMap
from lightwin.core.list_of_elements.list_of_elements import ListOfElements
from lightwin.failures.set_of_cavity_settings import SetOfCavitySettings
from lightwin.optimisation.algorithms.algorithm import OptimisationAlgorithm
from lightwin.optimisation.algorithms.bayesian_optimization import (
BayesianOptimizationLW,
)
from lightwin.optimisation.algorithms.differential_evolution import (
DifferentialEvolution,
)
from lightwin.optimisation.algorithms.downhill_simplex import DownhillSimplex
from lightwin.optimisation.algorithms.downhill_simplex_penalty import (
DownhillSimplexPenalty,
)
from lightwin.optimisation.algorithms.explorator import Explorator
from lightwin.optimisation.algorithms.least_squares import LeastSquares
from lightwin.optimisation.algorithms.least_squares_penalty import (
LeastSquaresPenalty,
)
from lightwin.optimisation.algorithms.nsga import (
NSGA3Algorithm,
NSGA3AlgorithmMulti,
)
from lightwin.optimisation.algorithms.predefined_solution import (
PredefinedSolution,
)
from lightwin.optimisation.algorithms.simulated_annealing import (
SimulatedAnnealing,
)
from lightwin.optimisation.design_space.design_space import DesignSpace
from lightwin.optimisation.objective.factory import ObjectiveFactory
from lightwin.util.typing import OPTIMIZATION_STATUS
#: Maps the ``optimisation_algorithm`` key in the ``TOML`` file to the actual
#: :class:`.OptimisationAlgorithm` subclass we use.
#: Some keys are aliases of the same class: ``"nelder_mead"`` /
#: ``"downhill_simplex"``, ``"nelder_mead_penalty"`` /
#: ``"downhill_simplex_penalty"``, and ``"experimental"`` /
#: ``"bayesian_optimization"``.
#: The keys must be kept in sync with :data:`ALGORITHMS_T`.
ALGORITHM_SELECTOR: dict[str, type[OptimisationAlgorithm]] = {
    "bayesian_optimization": BayesianOptimizationLW,
    "differential_evolution": DifferentialEvolution,
    "downhill_simplex": DownhillSimplex,
    "downhill_simplex_penalty": DownhillSimplexPenalty,
    # NOTE(review): currently just an alias of "bayesian_optimization".
    "experimental": BayesianOptimizationLW,
    "explorator": Explorator,
    "least_squares": LeastSquares,
    "least_squares_penalty": LeastSquaresPenalty,
    "nelder_mead": DownhillSimplex,
    "nelder_mead_penalty": DownhillSimplexPenalty,
    "NSGA-III": NSGA3Algorithm,
    "NSGA-III Multi-threaded": NSGA3AlgorithmMulti,
    "simulated_annealing": SimulatedAnnealing,
}

#: Implemented optimization algorithms, i.e. the allowed keys of
#: :data:`ALGORITHM_SELECTOR`. Keep both definitions in sync.
ALGORITHMS_T = Literal[
    "bayesian_optimization",
    "differential_evolution",
    "downhill_simplex",
    "downhill_simplex_penalty",
    "experimental",
    "explorator",
    "least_squares",
    "least_squares_penalty",
    "nelder_mead",
    "nelder_mead_penalty",
    "NSGA-III",
    "NSGA-III Multi-threaded",
    "simulated_annealing",
]
class OptimisationAlgorithmFactory:
    """Holds methods to easily create :class:`.OptimisationAlgorithm`."""

    def __init__(
        self,
        opti_method: ALGORITHMS_T,
        beam_calculator: BeamCalculator,
        reference_simulation_output: SimulationOutput,
        accelerator_id: str,
        **wtf: Any,
    ) -> None:
        """Save properties common to every optimization algorithm.

        Parameters
        ----------
        opti_method :
            Name of the desired optimisation algorithm. Must be a key of
            :data:`ALGORITHM_SELECTOR`.
        beam_calculator :
            Object that will be used to compute propagation of the beam.
        reference_simulation_output :
            Simulation of the nominal accelerator.
        accelerator_id :
            Associated solution :attr:`.Accelerator.id`. Looks like:
            ``0000001_Solution``.
        wtf :
            Other keyword arguments that will be passed to the
            :class:`.OptimisationAlgorithm` by :meth:`create`. They take
            precedence over the defaults built by
            :meth:`_make_default_kwargs`.

        Raises
        ------
        KeyError
            If ``opti_method`` is not a key of :data:`ALGORITHM_SELECTOR`.

        """
        try:
            self._class = ALGORITHM_SELECTOR[opti_method]
        except KeyError:
            # Same exception type as before, but with an actionable message
            # instead of the bare unknown key.
            raise KeyError(
                f"Unknown optimisation algorithm {opti_method!r}. Allowed "
                f"values are: {', '.join(ALGORITHM_SELECTOR)}."
            ) from None
        self._beam_calculator = beam_calculator
        self._wtf = wtf
        self._reference_simulation_output = reference_simulation_output
        self._accelerator_id = accelerator_id

    def create(
        self,
        compensating_elements: Collection[Element],
        objective_factory: ObjectiveFactory,
        design_space: DesignSpace,
        subset_elts: ListOfElements,
    ) -> OptimisationAlgorithm:
        """Instantiate an optimisation algorithm for a given fault.

        The defaults from :meth:`_make_default_kwargs` are merged with the
        keyword arguments given at initialisation; the latter win when a key
        appears in both.

        Parameters
        ----------
        compensating_elements :
            Compensating elements, forwarded to the algorithm.
        objective_factory :
            Forwarded to the algorithm.
        design_space :
            Forwarded to the algorithm.
        subset_elts :
            Elements over which the beam is propagated during optimisation.

        Returns
        -------
        OptimisationAlgorithm
            Instance of the class selected by ``opti_method``.

        """
        default_kwargs = self._make_default_kwargs(
            compensating_elements, objective_factory, design_space, subset_elts
        )
        self._log_common_keys(self._wtf, default_kwargs)
        # User kwargs are unpacked last so that they override the defaults.
        final_kwargs = {**default_kwargs, **self._wtf}
        return self._class(**final_kwargs)

    def create_from_preset(
        self,
        compensating_elements: Collection[Element],
        objective_factory: ObjectiveFactory,
        design_space: DesignSpace,
        subset_elts: ListOfElements,
        predefined_cavity_settings: SetOfCavitySettings,
        predefined_simulation_output: SimulationOutput | None = None,
    ) -> PredefinedSolution:
        """Instantiate a fake optimization algorithm bypassing the solver.

        Unlike :meth:`create`, the keyword arguments given at initialisation
        are not forwarded; only the defaults from
        :meth:`_make_default_kwargs` and the predefined solution are used.

        Parameters
        ----------
        compensating_elements :
            Compensating elements, forwarded to the algorithm.
        objective_factory :
            Forwarded to the algorithm.
        design_space :
            Forwarded to the algorithm.
        subset_elts :
            Elements over which the beam is propagated.
        predefined_cavity_settings :
            Cavity settings handed to :class:`.PredefinedSolution`.
        predefined_simulation_output :
            Optional simulation output handed to
            :class:`.PredefinedSolution`.

        Returns
        -------
        PredefinedSolution

        """
        default_kwargs = self._make_default_kwargs(
            compensating_elements, objective_factory, design_space, subset_elts
        )
        return PredefinedSolution(
            predefined_cavity_settings=predefined_cavity_settings,
            predefined_simulation_output=predefined_simulation_output,
            **default_kwargs,
        )

    def _make_default_kwargs(
        self,
        compensating_elements: Collection[Element],
        objective_factory: ObjectiveFactory,
        design_space: DesignSpace,
        subset_elts: ListOfElements,
    ) -> dict[str, Any]:
        """Build default arguments for :class:`.OptimisationAlgorithm`.

        In :meth:`create`, the keyword arguments given at initialisation
        (from :attr:`.Fault.optimisation_algorithm`) override the ones
        defined here.

        Returns
        -------
        dict[str, Any]
            Keyword arguments for the initialisation of
            :class:`.OptimisationAlgorithm`.

        """

        def compute_beam_propagation(
            cavity_settings: Mapping[FieldMap, CavitySettings] | None,
            **kwargs: Any,
        ):
            """Wrap propagation of the beam over ``subset_elts``.

            Parameters
            ----------
            cavity_settings :
                Maps compensating cavities with the settings to be tried.
            kwargs :
                Forwarded to the beam calculator's ``run_with_this``.

            """
            set_of_cavity_settings = SetOfCavitySettings.from_incomplete_set(
                compensating_cavity_settings=cavity_settings,
                cavities=subset_elts.cavities(superposed="remove"),
                optimization_status="in progress",
            )
            return self._beam_calculator.run_with_this(
                accelerator_id=self._accelerator_id,
                set_of_cavity_settings=set_of_cavity_settings,
                elts=subset_elts,
                **kwargs,
            )

        default_kwargs: dict[str, Any] = {
            "compensating_elements": compensating_elements,
            "objective_factory": objective_factory,
            "design_space": design_space,
            "compute_beam_propagation": compute_beam_propagation,
            "cavity_settings_factory": self._beam_calculator.cavity_settings_factory,
            "reference_simulation_output": self._reference_simulation_output,
        }
        return default_kwargs

    def _log_common_keys(
        self, user_kwargs: dict[str, Any], default_kwargs: dict[str, Any]
    ) -> None:
        """Log when user-provided and default kwargs overlap.

        Parameters
        ----------
        user_kwargs :
            kwargs as defined in the :attr:`.Fault.optimisation_algorithm`
            (they have precedence).
        default_kwargs :
            kwargs as built by :meth:`_make_default_kwargs` (they will be
            overridden, as they are considered "default" or "fallback"
            values).

        """
        overlap = user_kwargs.keys() & default_kwargs.keys()
        if not overlap:
            return
        # Sort so that the logged key order is deterministic (sets are not).
        logging.info(
            "Overlapping OptimisationAlgorithm kwargs detected:\n"
            "%s. User-provided values (from FaultScenario) "
            "will override defaults.",
            ", ".join(sorted(overlap)),
        )