Source code for lightwin.optimisation.objective.objective

"""Define a class to hold optimisation objective with its ideal value."""

import logging
from abc import ABC, abstractmethod
from collections.abc import Collection
from typing import Any, Self, Sequence

from numpy.typing import NDArray

from lightwin.beam_calculation.simulation_output.simulation_output import (
    SimulationOutput,
)
from lightwin.core.beam_parameters.helper import mismatch_from_arrays
from lightwin.util.typing import (
    GETTABLE_BEAM_PARAMETERS,
    GETTABLE_BEAM_PARAMETERS_T,
    GETTABLE_SIMULATION_OUTPUT,
    GETTABLE_SIMULATION_OUTPUT_T,
)


[docs] class Objective(ABC): """Hold an objective and methods to evaluate it. .. todo:: Should this object also store its final value? """ #: List of authorized values for the ``get_key``. Checked by the #: :meth:`._check_get_arguments` method _gettable: Collection[str] = GETTABLE_SIMULATION_OUTPUT
[docs] def __init__( self, name: str, weight: float, get_key: GETTABLE_SIMULATION_OUTPUT_T, get_kwargs: dict[str, Any], ideal_value: float | tuple[float, float] | None, descriptor: str | None = None, ) -> None: """Hold an objective and methods to evaluate it. Parameters ---------- name : A short string to describe the objective and access to it. weight : A scaling constant to set the weight of current objective. get_key : Name of the quantity to get. get_kwargs : Keyword arguments for the :meth:`.SimulationOutput.get` method. We do not enforce its validity, but in general you will want to define the keys ``elt`` and ``pos``. If objective concerns a phase, you may want to precise the ``to_deg`` key. You also should explicit the ``to_numpy`` key. ideal_value : Ideal value to match. - It is a float when we want to be as close as possible of a value. - It is a tuple of floats when we want to be within two bounds. - It is not defined (``None``) when we want to minimize or maximize an objective. descriptor : A longer string to explain the objective. """ get_key, get_kwargs = self._check_get_arguments(get_key, get_kwargs) #: Short string describing the objective. self.name: str = name #: Weight :math:`w` of current objective. self.weight: float = weight #: Name of the quantity to get from :class:`.SimulationOutput`, self.get_key: GETTABLE_SIMULATION_OUTPUT_T = get_key #: Keyword arguments for the :meth:`.SimulationOutput.get` method. self.get_kwargs: dict[str, Any] = get_kwargs #: Ideal value to match. #: #: - It is a float when we want to be as close as possible of a value. #: - It is a tuple of floats when we want to be within two bounds. #: - It is not defined (``None``) when we want to minimize or maximize #: an objective. self.ideal_value: Any = ideal_value self.descriptor = " ".join((descriptor or "").split()) #: Residual value at the end of the optimization process. self.residual: float
[docs] def __str__(self) -> str: """Give objective information value.""" message = self.position_nature() if isinstance(self.ideal_value, float): message += f"{self.ideal_value:+.14e}" return message if isinstance(self.ideal_value, tuple): message += ( f"{self.ideal_value[0]:+.2e} ~ {self.ideal_value[1]:+.2e}" ) return message if self.ideal_value is None: message += f"{'None': ^21}" return message return message
[docs] def position_nature(self) -> str: """Tell nature and position of objective.""" message = f"{self.get_key:>23}" elt = str(self.get_kwargs.get("elt", "NA")) message += f" @elt {elt:>5}" pos = str(self.get_kwargs.get("pos", "NA")) message += f" ({pos:>3}) | {self.weight:>5} | " return message
[docs] def _value_getter( self, simulation_output: SimulationOutput, handle_missing_elt: bool = False, ) -> float: """Get desired value using :meth:`.SimulationOutput.get` method. .. seealso:: :func:`.simulation_output.factory._element_to_index` Parameters ---------- simulation_output : Object to ``get`` ``self.get_key`` from. handle_missing_elt : Automatically look for an equivalent :class:`.Element` when the current one is not in :class:`.SimulationOutput`. Set it to ``True`` when calculating reference value (reference :class:`.Element` is not in compensating list of elements). """ return simulation_output.get( self.get_key, **self.get_kwargs, handle_missing_elt=handle_missing_elt, )
[docs] def _check_get_arguments( self, get_key: GETTABLE_SIMULATION_OUTPUT_T, get_kwargs: dict[str, Any], ) -> tuple[GETTABLE_SIMULATION_OUTPUT_T, dict[str, Any]]: """Check validity of ``get_args``, ``get_kwargs``. In general, residuals evaluation relies on a :meth:`.SimulationOutput.get` method. This method uses ``get_args`` and ``get_kwargs``; we perform here some basic checks. """ if get_key not in self._gettable: logging.warning( f"{get_key = } may not be gettable by SimulationOutput.get " f"method. Authorized values are:\n{self._gettable = }" ) advised_keys = ["elt", "pos", "to_numpy"] if "phi" in get_key: advised_keys.append("to_deg") for key in advised_keys: if key in get_kwargs: continue logging.warning( f"{key = } is recommended to avoid undetermined behavior but " "was not found." ) return get_key, get_kwargs
[docs] @staticmethod def str_header() -> str: """Give a header to explain what :meth:`__str__` returns.""" header = f"{'Objective': ^40} | {'wgt.':>5} | " header += f"{'ideal value': ^21}" return header
[docs] @staticmethod def str_header_solved() -> str: """Give a header to explain, after optimization, what is printed.""" return f"{Objective.str_header()} | {'final residuals': ^21}"
[docs] @abstractmethod def _compute_residuals(self, objective_value: Any) -> float: """Compute residual (loss), *ie* what we want to minimize. In general, you will want to call this function from :meth:`.Objective.evaluate`. Parameters ---------- objective_value : Value of :attr:`.Objective.name`, taken from a :class:`.SimulationOutput`. Returns ------- residual for current objective, scaled by :attr:`.Objective.weight`. """
[docs] def evaluate(self, simulation_output: SimulationOutput) -> float: """Get desired value from ``simulation_output`` and compute residuals. Parameters ---------- simulation_output : Object containing simulation results of the broken linac. Returns ------- Residual for current objective, scaled by :attr:`.Objective.weight`. """ objective_value = self._value_getter(simulation_output) return self._compute_residuals(objective_value)
[docs] class RetrieveArbitrary(Objective): """Retrieve arbitrary value given by user. You can also use it to minimize or maximize an objective. """
[docs] def __init__( self, name: str, weight: float, get_key: GETTABLE_SIMULATION_OUTPUT_T, get_kwargs: dict[str, Any], ideal_value: float, descriptor: str | None = None, ) -> None: """ Set complementary :meth:`.SimulationOutput.get` flags, reference value. Parameters ---------- name : A short string to describe the objective and access to it. weight : A scaling constant to set the weight of current objective. get_key : Name of the quantity to get. get_kwargs : Keyword arguments for the :meth:`.SimulationOutput.get` method. We do not check its validity, but in general you will want to define the keys ``elt`` and ``pos``. If objective concerns a phase, you may want to precise the ``to_deg`` key. You also should explicit the ``to_numpy`` key. ideal_value : The value to retrieve. descriptor : A longer string to explain the objective. """ self.ideal_value: float super().__init__( name=name, weight=weight, get_key=get_key, get_kwargs=get_kwargs, descriptor=descriptor, ideal_value=ideal_value, )
[docs] def _compute_residuals(self, objective_value: float) -> float: return self.weight * abs(objective_value - self.ideal_value)
[docs] class MinimizeDifferenceWithRef(Objective): """A simple difference at a given point between ref and fix."""
[docs] def __init__( self, name: str, weight: float, get_key: GETTABLE_SIMULATION_OUTPUT_T, get_kwargs: dict[str, Any], reference: SimulationOutput, descriptor: str | None = None, ) -> None: """ Set complementary :meth:`.SimulationOutput.get` flags, reference value. Parameters ---------- name : A short string to describe the objective and access to it. weight : A scaling constant to set the weight of current objective. get_key : Name of the quantity to get. get_kwargs : Keyword arguments for the :meth:`.SimulationOutput.get` method. We do not check its validity, but in general you will want to define the keys ``elt`` and ``pos``. If objective concerns a phase, you may want to precise the ``to_deg`` key. You also should explicit the ``to_numpy`` key. reference : The reference simulation output from which the ideal value will be taken. descriptor : A longer string to explain the objective. """ self.ideal_value: float super().__init__( name=name, weight=weight, get_key=get_key, get_kwargs=get_kwargs, descriptor=descriptor, ideal_value=reference.get( get_key, **get_kwargs, handle_missing_elt=True ), ) self._check_ideal_value()
[docs] def _check_ideal_value(self) -> None: """Assert the the reference value is a float.""" if not isinstance(self.ideal_value, float): logging.warning( f"Tried to get {self.get_key} with {self.get_kwargs}, which " f"returned {self.ideal_value} instead of a float." )
[docs] def _compute_residuals(self, objective_value: float) -> float: return self.weight * abs(objective_value - self.ideal_value)
[docs] class MinimizeMismatch(Objective): """Minimize a mismatch factor.""" _gettable = GETTABLE_BEAM_PARAMETERS
[docs] def __init__( self, name: str, weight: float, get_key: GETTABLE_BEAM_PARAMETERS_T, get_kwargs: dict[str, Any], reference: SimulationOutput, descriptor: str | None = None, ) -> None: """ Set complementary :meth:`.SimulationOutput.get` flags, reference value. Parameters ---------- name : A short string to describe the objective and access to it. weight : A scaling constant to set the weight of current objective. get_key : Must contain 'twiss' plus the name of a phase-space, or simply 'twiss' and the phase-space is defined in ``get_kwargs``. get_kwargs : Keyword arguments for the :meth:`.SimulationOutput.get` method. We do not check its validity, but in general you will want to define the keys ``elt`` and ``pos``. You should also define the ``phase_space_name`` key if it is not defined in the ``get_key``. reference : The reference simulation output from which the Twiss parameters will be taken. descriptor : A longer string to explain the objective. """ self.ideal_value: float self.get_key: GETTABLE_BEAM_PARAMETERS_T super().__init__( name=name, weight=weight, get_key=get_key, get_kwargs=get_kwargs, descriptor=descriptor, ideal_value=0.0, ) self._twiss_ref = self._twiss_getter(reference)
[docs] def _check_get_arguments( self, get_key: GETTABLE_SIMULATION_OUTPUT_T, get_kwargs: dict[str, Any] ) -> tuple[GETTABLE_SIMULATION_OUTPUT_T, dict[str, Any]]: """Add default values if necessary.""" if "twiss" not in get_key: logging.warning( "The get_key should contain 'twiss'. Taking 'twiss' and " "setting phase space to zdelta." ) get_key = "twiss" get_kwargs["phase_space_name"] = "zdelta" return super()._check_get_arguments(get_key, get_kwargs)
[docs] def _twiss_getter(self, simulation_output: SimulationOutput) -> NDArray: """Get desired value using :meth:`.SimulationOutput.get` method.""" return simulation_output.beam_parameters.get( self.get_key, **self.get_kwargs )
[docs] def evaluate(self, simulation_output: SimulationOutput) -> float: twiss_fix = self._twiss_getter(simulation_output) return self._compute_residuals(twiss_fix)
[docs] def _compute_residuals(self, objective_value: NDArray) -> float: res = mismatch_from_arrays(self._twiss_ref, objective_value)[0] return self.weight * res
[docs] class QuantityIsBetween(Objective): """Quantity must be within some bounds."""
[docs] def __init__( self, name: str, weight: float, get_key: GETTABLE_SIMULATION_OUTPUT_T, get_kwargs: dict[str, Any], limits: tuple[float, float], descriptor: str | None = None, loss_function: str | None = None, ) -> None: """ Set complementary :meth:`.SimulationOutput.get` flags, reference value. Parameters ---------- name : A short string to describe the objective and access to it. weight : A scaling constant to set the weight of current objective. get_key : Name of the quantity to get. get_kwargs : Keyword arguments for the :meth:`.SimulationOutput.get` method. We do not check its validity, but in general you will want to define the keys ``elt`` and ``pos``. If objective concerns a phase, you may want to precise the ``to_deg`` key. You also should explicit the ``to_numpy`` key. limits : Lower and upper bound for the value. loss_function : Indicates how the residuals are handled when the quantity is outside the limits. Currently not implemented. """ self.ideal_value: tuple[float, float] super().__init__( name=name, weight=weight, get_key=get_key, get_kwargs=get_kwargs, ideal_value=limits, descriptor=descriptor, ) if loss_function is not None: logging.warning("Loss functions not implemented.")
[docs] @classmethod def relative_to_reference( cls, name: str, weight: float, get_key: GETTABLE_SIMULATION_OUTPUT_T, get_kwargs: dict[str, Any], relative_limits: tuple[float, float], reference_value: float, descriptor: str | None = None, loss_function: str | None = None, ) -> Self: r""" Set complementary :meth:`.SimulationOutput.get` flags, reference value. Parameters ---------- name : A short string to describe the objective and access to it. weight : A scaling constant to set the weight of current objective. get_key : Name of the quantity to get. get_kwargs : Keyword arguments for the :meth:`.SimulationOutput.get` method. We do not check its validity, but in general you will want to define the keys ``elt`` and ``pos``. If objective concerns a phase, you may want to precise the ``to_deg`` key. You also should explicit the ``to_numpy`` key. relative_limits : Lower and upper bound for the value, in :unit:`\%` wrt ``reference_value``. First value should be lower than :math:`100\%`, second value higher than :math:`100\%`. reference_value : Ideal value. loss_function : Indicates how the residuals are handled when the quantity is outside the limits. Currently not implemented. """ assert relative_limits[0] <= 100.0 and relative_limits[1] >= 100.0, ( f"{relative_limits = } but should look like `(80, 135)` (which " "means: objective must be 80% and 135% of reference value." ) limits: tuple[float, float] limits = ( reference_value * 1e-2 * relative_limits[0], reference_value * 1e-2 * relative_limits[1], ) if reference_value <= 0.0: logging.info( f"{reference_value = } is negative. Inverting bounds to keep " "limits[0] < limits[1]." ) limits = (limits[1], limits[0]) return cls( name=name, weight=weight, get_key=get_key, get_kwargs=get_kwargs, limits=limits, descriptor=descriptor, loss_function=loss_function, )
[docs] def __str__(self) -> str: """Give objective information value.""" message = self.position_nature() message += f"{self.ideal_value[0]:+.2e} ~ {self.ideal_value[1]:+.2e}" # type: ignore return message
[docs] def _compute_residuals(self, objective_value: float) -> float: r"""Compute residual (loss), *ie* what we want to minimize. This method applies a quadratic penalty if the value lies outside the target interval defined by ``self.ideal_value``. No penalty is applied when the value is within the interval. Parameters ---------- objective_value : Value of :attr:`.Objective.name`, taken from a :class:`.SimulationOutput`. Returns ------- residual for current objective, scaled by :attr:`.Objective.weight`. The loss function is defined as: - :math:`0` if :math:`x_l \leq x \leq x_u`, *ie* if ``objective_value`` is within the bounds defined by :attr:`.Objective.ideal_value` - :math:`w \times (x - x_{l\,u})^2` otherwise, where :math:`x_{l,u}` is the violated boundary and :math:`w` is :attr:`.Objective.weight`. """ if objective_value < self.ideal_value[0]: return self.weight * (objective_value - self.ideal_value[0]) ** 2 if objective_value > self.ideal_value[1]: return self.weight * (objective_value - self.ideal_value[1]) ** 2 return 0.0
[docs] def str_objectives(objectives: Sequence[Objective]) -> str: """Return a string describing several objectives.""" info = [str(objective) for objective in objectives] info.insert(0, "=" * 100) info.insert(1, Objective.str_header()) info.insert(2, "-" * 100) info.append("=" * 100) return "\n".join(info)
[docs] def str_objectives_solved(objectives: Sequence[Objective]) -> str: """Return a string describing objectives results.""" try: info = [ f"{str(objective)} | {objective.residual:+.14e}" for objective in objectives ] except AttributeError as e: logging.error( "Something went wrong when trying to access the residuals of " f"objectives.\n{e}" ) return "No info on objective values, cf above message." info.insert(0, "=" * 100) info.insert(1, Objective.str_header_solved()) info.insert(2, "-" * 100) info.append("=" * 100) return "\n".join(info)