Source code for lightwin.visualization.plot

"""Define a library to produce all these nice plots.

.. todo::
    better detection of what is a multiparticle simulation and what is not.
    Currently looking for "'partran': 0" in the name of the solver, making the
    assumption that multipart is the default. But it depends on the .ini...
    update: just use .is_a_multiparticle_simulation

.. todo::
    Fix when there is only one accelerator to plot.

.. todo::
    Different plot according to dimension of FieldMap, or according to if it
    accelerates or not (ex when quadrupole defined by a field map)

"""

import logging
from collections.abc import Collection
from pathlib import Path
from typing import Any, Literal, Sequence

import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.typing import ColorType
from palettable.colorbrewer.qualitative import Dark2_8  # type: ignore

import lightwin.util.dicts_output as dic
from lightwin.core.accelerator.accelerator import Accelerator
from lightwin.failures.fault import Fault
from lightwin.failures.fault_scenario import FaultScenario
from lightwin.util.typing import GETTABLE_SIMULATION_OUTPUT_T
from lightwin.visualization import structure
from lightwin.visualization.data_getter import all_accelerators_data
from lightwin.visualization.helper import (
    X_AXIS_T,
    create_fig_if_not_exists,
    savefig,
)
from lightwin.visualization.optimization import mark_objectives_position

font = {"family": "serif"}
plt.rc("font", **font)
plt.rcParams["axes.prop_cycle"] = cycler(color=Dark2_8.mpl_colors)

FALLBACK_PRESETS = {"x_axis": "z_abs", "plot_section": True, "sharex": True}
PLOT_PRESETS = {
    "acceptance": {
        "x_axis": "elt_idx",
        "all_y_axis": ("acceptance_phi", "acceptance_energy", "struct"),
        "num": 28,
        "symmetric_plot": True,
    },
    "cav": {
        "x_axis": "elt_idx",
        "all_y_axis": ("v_cav_mv", "phi_s", "struct"),
        "num": 23,
    },
    "emittance": {
        "x_axis": "z_abs",
        "all_y_axis": ("eps_phiw", "struct"),
        "num": 24,
    },
    "energy": {
        "x_axis": "z_abs",
        "all_y_axis": ("w_kin", "w_kin_err", "struct"),
        "num": 21,
    },
    "envelopes": {
        "x_axis": "z_abs",
        "all_y_axis": (
            "envelope_pos_phiw",
            "envelope_energy_phiw",
            "struct",
        ),
        "num": 26,
        "symmetric_plot": True,
    },
    "mismatch_factor": {
        "x_axis": "z_abs",
        "all_y_axis": ("mismatch_factor_zdelta", "struct"),
        "num": 27,
    },
    "phase": {
        "x_axis": "z_abs",
        "all_y_axis": ("phi_abs", "phi_abs_err", "struct"),
        "num": 22,
    },
    "transfer_matrices": {
        "x_axis": "z_abs",
        "all_y_axis": (
            "r_zdelta_11",
            "r_zdelta_12",
            "r_zdelta_21",
            "r_zdelta_22",
        ),
        "num": 29,
    },
    "twiss": {
        "x_axis": "z_abs",
        "all_y_axis": ("alpha_phiw", "beta_phiw", "struct"),
        "num": 25,
    },
}
ERROR_PRESETS = {
    "w_kin_err": {"scale": 1.0, "diff": "simple"},
    "phi_abs_err": {"scale": 1.0, "diff": "simple"},
}
#: List of implemented presets for the plots
ALLOWED_PLOT_PRESETS = list(PLOT_PRESETS.keys())

# The one you generally want
ERROR_REFERENCE = "ref accelerator (1st solv w/ 1st solv, 2nd w/ 2nd)"

# These two are useful when you want to study the differences between
# two solvers
# ERROR_REFERENCE = "ref accelerator (1st solver)"
# ERROR_REFERENCE = "ref accelerator (2nd solver)"


# =============================================================================
# Front end
# =============================================================================
[docs] def factory( accelerators: dict[int, list[Accelerator]], plots: dict[str, Any], save_fig: bool = True, clean_fig: bool = True, fault_scenarios: Sequence[FaultScenario] | None = None, only_solver_id: Collection[str] | str | None = None, **kwargs, ) -> list[Figure]: """Create all the desired plots. Parameters ---------- accelerators : Mapping of scenario index to a list of accelerators for that scenario. Key ``0`` is always the reference (single element). All other keys are "fixed" scenarios, each potentially holding several alternative accelerators that will be overlaid on the same figure:: { 0: [reference_accelerator], 1: [fixed_01], 2: [fixed_02, alternative_02], } plots : The plot ``TOML`` table. save_fig : If Figures should be saved. clean_fig : If Figures should be cleaned between two calls of this function. fault_scenarios : If provided, the position of the :class:`.Objective` will also appear on plots. kwargs : Other tables from the ``TOML`` configuration file. Returns ------- The created figures. """ if 0 not in accelerators: raise ValueError("accelerators must contain key 0 (reference).") if len(accelerators[0]) != 1: raise ValueError( "Reference scenario (key 0) must hold exactly one accelerator." ) if clean_fig and not save_fig and len(accelerators) > 2: logging.warning( "You will only see the plots of the last scenario; previous " "figures will be erased without saving." ) plots_presets, plots_kwargs = ( _separate_plot_presets_from_plot_modificators(plots) ) merged_kwargs = kwargs | plots_kwargs plot_groups = _build_plot_groups(accelerators) figs: list[Figure] = [ _plot_preset( preset, *accelerators_to_plot, save_fig=save_fig, clean_fig=clean_fig, fault_scenarios=fault_scenarios, only_solver_id=only_solver_id, **_proper_kwargs(preset, merged_kwargs), ) for accelerators_to_plot in plot_groups for preset, plot_me in plots_presets.items() if plot_me ] return figs
[docs] def _separate_plot_presets_from_plot_modificators( plots: dict[str, Any], ) -> tuple[dict[str, bool], dict[str, Any]]: """Separate the config entries corresponding to the name of a plot. Parameters ---------- plots : Dictionary holding the plot configuration. Returns ------- plot_presets : Subset of ``plots``, with only the keys that can be found in :data:`ALLOWED_PLOT_PRESETS`. Indicates which plots presets will be plotted: ``"cav"``, ``"emittance"``... plot_kwargs : Subset of ``plots``, with only the keys corresponding to a plot modificator, eg ``"add_objectives"``. """ plot_presets: dict[str, bool] = {} plot_kwargs: dict[str, Any] = {} for key, value in plots.items(): if key in PLOT_PRESETS: plot_presets[key] = value continue plot_kwargs[key] = value return plot_presets, plot_kwargs
[docs] def _build_plot_groups( accelerators: dict[int, list[Accelerator]], ) -> list[list[Accelerator]]: """Build the groups of accelerators to plot together. Each group will produce one figure per preset. The reference accelerator is always first. Scenario 0 produces a group of just ``[ref_acc]``. Other scenarios produce ``[ref_acc, *computed_accs]``, and are skipped entirely if none of their accelerators are computed yet. Parameters ---------- accelerators : The full scenario mapping as passed to :func:`factory`. Must have the ``0`` key, the corresponding value must be a list containing only the reference |A|. Returns ------- List of accelerator groups, one per scenario to plot. """ plot_groups: list[list[Accelerator]] = [] ref_acc = accelerators[0][0] for scenario_idx, scenario_accs in accelerators.items(): if scenario_idx == 0: continue computed = [acc for acc in scenario_accs if acc.is_computed()] n_skipped = len(scenario_accs) - len(computed) if n_skipped > 0: logging.info( f"Scenario {scenario_idx}: skipping {n_skipped} uncomputed " f"accelerator(s) out of {len(scenario_accs)}." ) if len(computed) == 0: logging.info( f"Scenario {scenario_idx}: no computed accelerators, skipping " "all presets for this scenario." ) continue plot_groups.append([ref_acc, *computed]) if len(plot_groups) == 0: plot_groups.append([ref_acc]) return plot_groups
[docs] def _plot_preset( preset: str, *accelerators_to_plot, all_y_axis: list[GETTABLE_SIMULATION_OUTPUT_T | Literal["struct"]], x_axis: X_AXIS_T = "z_abs", save_fig: bool = True, clean_fig: bool = True, add_objectives: bool = False, fault_scenarios: Sequence[list[Fault]] | None = None, usr_kwargs: dict[str, Any] | None = None, get_kwargs: dict[str, bool] | None = None, symmetric_plot: bool = False, only_solver_id: Collection[str] | str | None = None, **kwargs, ) -> Figure: """Plot a preset showing reference and all fixed alternatives for one scenario. Parameters ---------- preset : Key of :data:`ALLOWED_PLOT_PRESETS`. *accelerators_to_plot : Accelerators to plot. First is always the reference. May contain only the reference (scenario 0), or reference + one or more fixed alternatives. all_y_axis : Name of all the y axis. x_axis : Name of the x axis. save_fig : To save Figures or not. Figure is saved to the path of ``fix_accs[0]``. add_objectives : To add the position of objectives to the plots; if True, the ``fault_scenarios`` must be provided. fault_scenarios : To plot the objectives, if ``add_objectives == True``. usr_kwargs : User-defined ``kwargs``, passed to the |axplot| method. get_kwargs : Keyword arguments for the :meth:`.SimulationOutput.get` methods. symmetric_plot : If plot should be symmetric around the x axis. only_solver_id : If set, we plot only data obtained with this solver(s). Must be :attr:`.BeamCalculator.id` (or, equivalently, a key(s) in :attr:`.Accelerator.simulation_outputs`). Typical values: ``"0_Envelope1D"`` or ``"1_TraceWin"``. **kwargs : Holds all complementary data on the plots. """ fig, axx = create_fig_if_not_exists( len(all_y_axis), clean_fig=clean_fig, **kwargs ) colors: dict[str, ColorType] | None = None for i, (ax, y_axis) in enumerate(zip(axx, all_y_axis)): try: _make_a_subplot( ax, x_axis, y_axis, colors, *accelerators_to_plot, get_kwargs=get_kwargs, symmetric_plot=symmetric_plot, only_solver_id=only_solver_id, **(usr_kwargs or {}), ) except ValueError as e: logging.error( f"A ValueError was raised when trying to plot {y_axis} vs " f"{x_axis}. This likely an error caused by inconsistent " f"x and y data.\n{e}" ) raise e if i == 0: colors = _used_colors(ax) if add_objectives: mark_objectives_position(ax, fault_scenarios, y_axis, x_axis) axx[0].legend() axx[-1].set_xlabel(dic.markdown[x_axis]) if save_fig: file = Path( accelerators_to_plot[-1].get("accelerator_path"), f"{preset}.png" ) savefig(fig, file) return fig
[docs] def _proper_kwargs(preset: str, kwargs: dict[str, Any]) -> dict[str, Any]: """Merge dicts, priority kwargs > PLOT_PRESETS > FALLBACK_PRESETS. We also add a ``"usr_kwargs"`` key holding additional keywords, that will be passed to |axplot|. """ merged = FALLBACK_PRESETS | PLOT_PRESETS[preset] | kwargs if "kwargs" in merged: merged["usr_kwargs"] = merged.pop("kwargs") return merged
[docs] def _used_colors(axe: Axes) -> dict[str, ColorType]: """Associate every line label to a color.""" lines = axe.get_lines() colors = {str(line.get_label()): line.get_color() for line in lines} return colors
[docs] def _y_label(y_axis: str) -> str: """Set the proper y axis label.""" if "_err" in y_axis: key = ERROR_PRESETS[y_axis]["diff"] y_label = dic.markdown["err_" + key] return y_label y_label = dic.markdown[y_axis] return y_label
# Actual interface with matplotlib
[docs] def _make_a_subplot( axe: Axes, x_axis: X_AXIS_T, y_axis: GETTABLE_SIMULATION_OUTPUT_T | Literal["struct"], colors: dict[str, ColorType] | None, *accelerators: Accelerator, plot_section: bool = True, symmetric_plot: bool = False, get_kwargs: dict[str, bool] | None = None, only_solver_id: Collection[str] | str | None = None, **usr_kwargs, ) -> None: """Get proper data and plot it on an Axe. Parameters ---------- axe : Object on which to add plot data. x_axis : Nature of x axis. y_axis : What to plot. colors : Holds the line labels from previous plots and associate it to their colors. accelerators : Objects from which we take ``y_axis``. plot_section : To outline the different sections in the background of the plots. symmetric_plot : If a symmetric plot (wrt x axis) should be added. get_kwargs : Keyword arguments for the :meth:`.SimulationOutput.get` method. only_solver_id : If set, we plot only data obtained with this solver(s). Must be :attr:`.BeamCalculator.id` (or, equivalently, a key(s) in :attr:`.Accelerator.simulation_outputs`). Typical values: ``"0_Envelope1D"`` or ``"1_TraceWin"``. usr_kwargs : User-defined ``kwargs``, passed to the |axplot| method. """ if plot_section: structure.outline_sections(accelerators[0].elts, axe, x_axis=x_axis) if y_axis == "struct": return structure.plot_structure( accelerators[-1].elts, axe, x_axis=x_axis ) x_data, y_data, plt_kwargs = all_accelerators_data( x_axis, y_axis, *accelerators, error_presets=ERROR_PRESETS, error_reference=ERROR_REFERENCE, only_solver_id=only_solver_id, **(get_kwargs or {}), ) # Alternate markers for the "cav" preset markers = ("o", "^") marker_index = 0 for x, y, _plt_kwargs in zip(x_data, y_data, plt_kwargs): if y_axis in ("v_cav_mv", "phi_s"): _plt_kwargs["marker"] = markers[marker_index] marker_index = (marker_index + 1) % len(markers) if colors is not None and _plt_kwargs["label"] in colors: _plt_kwargs["color"] = colors[_plt_kwargs["label"]] (line,) = axe.plot(x, y, **_plt_kwargs | usr_kwargs) if symmetric_plot: symmetric_kwargs = _plt_kwargs | { "color": line.get_color(), "label": None, } axe.plot(x, -y, **symmetric_kwargs) axe.grid(True) axe.set_ylabel(_y_label(y_axis))
[docs] def plot_pty_with_data_tags(ax, x, y, idx_list, tags=True): """Plot y vs x. Data at idx_list are magnified with bigger points and data tags. """ (line,) = ax.plot(x, y) ax.scatter(x[idx_list], y[idx_list], color=line.get_color()) if tags: n = len(idx_list) for i in range(n): txt = ( str(np.round(x[idx_list][i], 4)) + "," + str(np.round(y[idx_list][i], 4)) ) ax.annotate(txt, (x[idx_list][i], y[idx_list[i]]), size=8)