"""Define a library to produce all these nice plots.
.. todo::
    Better detection of what is and what is not a multiparticle simulation.
    Currently, we look for ``"'partran': 0"`` in the name of the solver,
    assuming that multiparticle is the default; but this actually depends on
    the ``.ini`` file.
    Update: just use ``.is_a_multiparticle_simulation``.
.. todo::
Fix when there is only one accelerator to plot.
.. todo::
    Use different plots depending on the dimension of the FieldMap, or on
    whether it accelerates or not (e.g. a quadrupole defined by a field map).
"""
import logging
from collections.abc import Collection
from pathlib import Path
from typing import Any, Literal, Sequence
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.typing import ColorType
from palettable.colorbrewer.qualitative import Dark2_8 # type: ignore
import lightwin.util.dicts_output as dic
from lightwin.core.accelerator.accelerator import Accelerator
from lightwin.failures.fault import Fault
from lightwin.failures.fault_scenario import FaultScenario
from lightwin.util.typing import GETTABLE_SIMULATION_OUTPUT_T
from lightwin.visualization import structure
from lightwin.visualization.data_getter import all_accelerators_data
from lightwin.visualization.helper import (
X_AXIS_T,
create_fig_if_not_exists,
savefig,
)
from lightwin.visualization.optimization import mark_objectives_position
# Global matplotlib look: serif font and the colorbrewer "Dark2_8" palette.
# NOTE: these calls modify matplotlib's global rc settings as soon as this
# module is imported, so they also affect plots made elsewhere.
font = {"family": "serif"}
plt.rc("font", **font)
plt.rcParams["axes.prop_cycle"] = cycler(color=Dark2_8.mpl_colors)
#: Lowest-priority settings shared by every preset. Entries of
#: :data:`PLOT_PRESETS`, then user-provided keyword arguments, override them
#: (see :func:`_proper_kwargs`).
FALLBACK_PRESETS = {"x_axis": "z_abs", "plot_section": True, "sharex": True}
#: Settings of every implemented plot preset, keyed by preset name.
#: ``"all_y_axis"`` holds one quantity per subplot (``"struct"`` draws the
#: accelerator structure instead of simulation data); ``"num"`` is forwarded to
#: :func:`.create_fig_if_not_exists` (presumably a matplotlib figure number --
#: confirm); ``"symmetric_plot"`` also draws every curve mirrored with respect
#: to the x axis.
PLOT_PRESETS = {
    "acceptance": {
        "x_axis": "elt_idx",
        "all_y_axis": ("acceptance_phi", "acceptance_energy", "struct"),
        "num": 28,
        "symmetric_plot": True,
    },
    "cav": {
        "x_axis": "elt_idx",
        "all_y_axis": ("v_cav_mv", "phi_s", "struct"),
        "num": 23,
    },
    "emittance": {
        "x_axis": "z_abs",
        "all_y_axis": ("eps_phiw", "struct"),
        "num": 24,
    },
    "energy": {
        "x_axis": "z_abs",
        "all_y_axis": ("w_kin", "w_kin_err", "struct"),
        "num": 21,
    },
    "envelopes": {
        "x_axis": "z_abs",
        "all_y_axis": (
            "envelope_pos_phiw",
            "envelope_energy_phiw",
            "struct",
        ),
        "num": 26,
        "symmetric_plot": True,
    },
    "mismatch_factor": {
        "x_axis": "z_abs",
        "all_y_axis": ("mismatch_factor_zdelta", "struct"),
        "num": 27,
    },
    "phase": {
        "x_axis": "z_abs",
        "all_y_axis": ("phi_abs", "phi_abs_err", "struct"),
        "num": 22,
    },
    # No "struct" subplot for the transfer matrix components.
    "transfer_matrices": {
        "x_axis": "z_abs",
        "all_y_axis": (
            "r_zdelta_11",
            "r_zdelta_12",
            "r_zdelta_21",
            "r_zdelta_22",
        ),
        "num": 29,
    },
    "twiss": {
        "x_axis": "z_abs",
        "all_y_axis": ("alpha_phiw", "beta_phiw", "struct"),
        "num": 25,
    },
}
#: Settings of the error quantities (``"*_err"`` y axes). The whole dict is
#: passed to :func:`.all_accelerators_data`; ``"diff"`` also selects the y
#: label, ``dic.markdown["err_" + diff]``. NOTE(review): the exact meaning of
#: ``"scale"`` is defined in :func:`.all_accelerators_data` -- presumably a
#: multiplicative factor; confirm.
ERROR_PRESETS = {
    "w_kin_err": {"scale": 1.0, "diff": "simple"},
    "phi_abs_err": {"scale": 1.0, "diff": "simple"},
}
#: List of implemented presets for the plots
ALLOWED_PLOT_PRESETS = list(PLOT_PRESETS.keys())
#: Reference passed to :func:`.all_accelerators_data` as ``error_reference``
#: (presumably the data the errors are computed against). This is the value
#: you generally want.
ERROR_REFERENCE = "ref accelerator (1st solv w/ 1st solv, 2nd w/ 2nd)"
# The two following alternatives are useful when you want to study the
# differences between two solvers:
# ERROR_REFERENCE = "ref accelerator (1st solver)"
# ERROR_REFERENCE = "ref accelerator (2nd solver)"
# =============================================================================
# Front end
# =============================================================================
[docs]
def factory(
    accelerators: dict[int, list[Accelerator]],
    plots: dict[str, Any],
    save_fig: bool = True,
    clean_fig: bool = True,
    fault_scenarios: Sequence[FaultScenario] | None = None,
    only_solver_id: Collection[str] | str | None = None,
    **kwargs,
) -> list[Figure]:
    """Create all the desired plots.

    One figure is produced for every (plot group, enabled preset) pair. The
    plot groups are built by :func:`_build_plot_groups`.

    Parameters
    ----------
    accelerators :
        Mapping of scenario index to a list of accelerators for that scenario.
        Key ``0`` is always the reference (single element). All other keys are
        "fixed" scenarios, each potentially holding several alternative
        accelerators that will be overlaid on the same figure::

            {
                0: [reference_accelerator],
                1: [fixed_01],
                2: [fixed_02, alternative_02],
            }

    plots :
        The plot ``TOML`` table. Keys found in :data:`PLOT_PRESETS` tell
        which presets are plotted; all other keys are plot modificators
        (eg ``"add_objectives"``) forwarded to every preset.
    save_fig :
        If Figures should be saved.
    clean_fig :
        If Figures should be cleaned between two calls of this function.
    fault_scenarios :
        If provided, the position of the :class:`.Objective` will also appear
        on plots.
    only_solver_id :
        If set, we plot only data obtained with this solver(s). Must be
        :attr:`.BeamCalculator.id` (or, equivalently, a key(s) in
        :attr:`.Accelerator.simulation_outputs`). Typical values:
        ``"0_Envelope1D"`` or ``"1_TraceWin"``.
    kwargs :
        Other tables from the ``TOML`` configuration file.

    Returns
    -------
    The created figures.

    Raises
    ------
    ValueError
        If ``accelerators`` has no key ``0``, or if key ``0`` does not hold
        exactly one accelerator.

    """
    if 0 not in accelerators:
        raise ValueError("accelerators must contain key 0 (reference).")
    if len(accelerators[0]) != 1:
        raise ValueError(
            "Reference scenario (key 0) must hold exactly one accelerator."
        )

    plots_presets, plots_kwargs = (
        _separate_plot_presets_from_plot_modificators(plots)
    )
    # Plot modificators from the plot table override other TOML tables.
    merged_kwargs = kwargs | plots_kwargs
    plot_groups = _build_plot_groups(accelerators)

    # Warn based on the groups that will really be plotted: scenarios without
    # any computed accelerator are skipped by ``_build_plot_groups`` and do not
    # produce figures that could be erased.
    if clean_fig and not save_fig and len(plot_groups) > 1:
        logging.warning(
            "You will only see the plots of the last scenario; previous "
            "figures will be erased without saving."
        )

    figs: list[Figure] = [
        _plot_preset(
            preset,
            *accelerators_to_plot,
            save_fig=save_fig,
            clean_fig=clean_fig,
            fault_scenarios=fault_scenarios,
            only_solver_id=only_solver_id,
            **_proper_kwargs(preset, merged_kwargs),
        )
        for accelerators_to_plot in plot_groups
        for preset, plot_me in plots_presets.items()
        if plot_me
    ]
    return figs
[docs]
def _separate_plot_presets_from_plot_modificators(
    plots: dict[str, Any],
) -> tuple[dict[str, bool], dict[str, Any]]:
    """Split the plot configuration into presets and modificators.

    Parameters
    ----------
    plots :
        Dictionary holding the plot configuration.

    Returns
    -------
    plot_presets :
        Entries of ``plots`` whose key is a preset name of
        :data:`PLOT_PRESETS` (``"cav"``, ``"emittance"``...); the value tells
        whether that preset is plotted.
    plot_kwargs :
        Every other entry of ``plots``, i.e. the plot modificators such as
        ``"add_objectives"``.

    """
    # Both comprehensions walk ``plots`` in the same order, so each output
    # keeps the original insertion order of its keys.
    presets = {name: flag for name, flag in plots.items() if name in PLOT_PRESETS}
    modificators = {
        name: setting
        for name, setting in plots.items()
        if name not in PLOT_PRESETS
    }
    return presets, modificators
[docs]
def _build_plot_groups(
accelerators: dict[int, list[Accelerator]],
) -> list[list[Accelerator]]:
"""Build the groups of accelerators to plot together.
Each group will produce one figure per preset. The reference accelerator
is always first. Scenario 0 produces a group of just ``[ref_acc]``.
Other scenarios produce ``[ref_acc, *computed_accs]``, and are skipped
entirely if none of their accelerators are computed yet.
Parameters
----------
accelerators :
The full scenario mapping as passed to :func:`factory`. Must have the
``0`` key, the corresponding value must be a list containing only
the reference |A|.
Returns
-------
List of accelerator groups, one per scenario to plot.
"""
plot_groups: list[list[Accelerator]] = []
ref_acc = accelerators[0][0]
for scenario_idx, scenario_accs in accelerators.items():
if scenario_idx == 0:
continue
computed = [acc for acc in scenario_accs if acc.is_computed()]
n_skipped = len(scenario_accs) - len(computed)
if n_skipped > 0:
logging.info(
f"Scenario {scenario_idx}: skipping {n_skipped} uncomputed "
f"accelerator(s) out of {len(scenario_accs)}."
)
if len(computed) == 0:
logging.info(
f"Scenario {scenario_idx}: no computed accelerators, skipping "
"all presets for this scenario."
)
continue
plot_groups.append([ref_acc, *computed])
if len(plot_groups) == 0:
plot_groups.append([ref_acc])
return plot_groups
[docs]
def _plot_preset(
    preset: str,
    *accelerators_to_plot,
    all_y_axis: list[GETTABLE_SIMULATION_OUTPUT_T | Literal["struct"]],
    x_axis: X_AXIS_T = "z_abs",
    save_fig: bool = True,
    clean_fig: bool = True,
    add_objectives: bool = False,
    fault_scenarios: Sequence[list[Fault]] | None = None,
    usr_kwargs: dict[str, Any] | None = None,
    get_kwargs: dict[str, bool] | None = None,
    symmetric_plot: bool = False,
    only_solver_id: Collection[str] | str | None = None,
    **kwargs,
) -> Figure:
    """Plot a preset showing reference and all fixed alternatives for one
    scenario.

    Parameters
    ----------
    preset :
        Key of :data:`ALLOWED_PLOT_PRESETS`. Also used as the name of the
        saved file (``{preset}.png``).
    *accelerators_to_plot :
        Accelerators to plot. First is always the reference. May contain only
        the reference (scenario 0), or reference + one or more fixed
        alternatives.
    all_y_axis :
        Name of all the y axis, one per subplot.
    x_axis :
        Name of the x axis.
    save_fig :
        To save Figures or not. The figure is saved in the
        ``"accelerator_path"`` of the LAST accelerator of
        ``accelerators_to_plot``.
    clean_fig :
        Forwarded to :func:`.create_fig_if_not_exists`.
    add_objectives :
        To add the position of objectives to the plots; if True, the
        ``fault_scenarios`` must be provided. If they are not, a warning is
        logged and the objectives are not drawn.
    fault_scenarios :
        To plot the objectives, if ``add_objectives == True``.
    usr_kwargs :
        User-defined ``kwargs``, passed to the |axplot| method.
    get_kwargs :
        Keyword arguments for the :meth:`.SimulationOutput.get` methods.
    symmetric_plot :
        If plot should be symmetric around the x axis.
    only_solver_id :
        If set, we plot only data obtained with this solver(s). Must be
        :attr:`.BeamCalculator.id` (or, equivalently, a key(s) in
        :attr:`.Accelerator.simulation_outputs`). Typical values:
        ``"0_Envelope1D"`` or ``"1_TraceWin"``.
    **kwargs :
        Holds all complementary data on the plots; forwarded to
        :func:`.create_fig_if_not_exists`. Its ``"plot_section"`` entry, if
        any, is also forwarded to every subplot.

    Returns
    -------
    The figure holding the preset.

    Raises
    ------
    ValueError
        Re-raised when a subplot cannot be drawn (typically inconsistent x and
        y data).

    """
    if add_objectives and fault_scenarios is None:
        logging.warning(
            f"add_objectives is True for the {preset} preset, but no "
            "fault_scenarios were given; objectives will not be shown."
        )
        add_objectives = False

    fig, axx = create_fig_if_not_exists(
        len(all_y_axis), clean_fig=clean_fig, **kwargs
    )

    # ``plot_section`` comes from FALLBACK_PRESETS / user config and lands in
    # ``kwargs``; it is kept there for the figure creation, but it must also
    # reach the subplots, otherwise setting it to False had no effect on them.
    # User kwargs still take precedence, as they did before.
    subplot_kwargs = {"plot_section": kwargs.get("plot_section", True)} | (
        usr_kwargs or {}
    )

    # Colors of the lines of the first subplot, keyed by label; reused on the
    # following subplots so that one accelerator keeps one color.
    colors: dict[str, ColorType] | None = None
    for i, (ax, y_axis) in enumerate(zip(axx, all_y_axis)):
        try:
            _make_a_subplot(
                ax,
                x_axis,
                y_axis,
                colors,
                *accelerators_to_plot,
                get_kwargs=get_kwargs,
                symmetric_plot=symmetric_plot,
                only_solver_id=only_solver_id,
                **subplot_kwargs,
            )
        except ValueError as e:
            logging.error(
                f"A ValueError was raised when trying to plot {y_axis} vs "
                f"{x_axis}. This likely an error caused by inconsistent "
                f"x and y data.\n{e}"
            )
            raise
        if i == 0:
            colors = _used_colors(ax)
        if add_objectives:
            mark_objectives_position(ax, fault_scenarios, y_axis, x_axis)

    axx[0].legend()
    axx[-1].set_xlabel(dic.markdown[x_axis])

    if save_fig:
        file = Path(
            accelerators_to_plot[-1].get("accelerator_path"), f"{preset}.png"
        )
        savefig(fig, file)
    return fig
[docs]
def _proper_kwargs(preset: str, kwargs: dict[str, Any]) -> dict[str, Any]:
    """Build the keyword arguments of one preset.

    Precedence, from highest to lowest: ``kwargs``, then the entry of
    :data:`PLOT_PRESETS` for ``preset``, then :data:`FALLBACK_PRESETS`. A
    ``"kwargs"`` entry, if present, is renamed ``"usr_kwargs"`` so that it
    eventually reaches |axplot|.
    """
    merged = {**FALLBACK_PRESETS, **PLOT_PRESETS[preset], **kwargs}
    if "kwargs" not in merged:
        return merged
    merged["usr_kwargs"] = merged.pop("kwargs")
    return merged
[docs]
def _used_colors(axe: Axes) -> dict[str, ColorType]:
"""Associate every line label to a color."""
lines = axe.get_lines()
colors = {str(line.get_label()): line.get_color() for line in lines}
return colors
[docs]
def _y_label(y_axis: str) -> str:
    """Return the y axis label of ``y_axis``.

    Quantities registered in :data:`ERROR_PRESETS` get the generic label of
    their kind of difference, ``dic.markdown["err_" + diff]``. Any other
    quantity uses its own ``dic.markdown`` entry.

    Raises
    ------
    KeyError
        If the needed entry is missing from ``dic.markdown``.

    """
    # Membership test rather than a ``"_err"`` substring test: a name that
    # merely contains "_err" but is not a registered error preset used to
    # raise a KeyError on ERROR_PRESETS instead of using its own label.
    error_preset = ERROR_PRESETS.get(y_axis)
    if error_preset is not None:
        return dic.markdown["err_" + error_preset["diff"]]
    return dic.markdown[y_axis]
# Actual interface with matplotlib
[docs]
def _make_a_subplot(
    axe: Axes,
    x_axis: X_AXIS_T,
    y_axis: GETTABLE_SIMULATION_OUTPUT_T | Literal["struct"],
    colors: dict[str, ColorType] | None,
    *accelerators: Accelerator,
    plot_section: bool = True,
    symmetric_plot: bool = False,
    get_kwargs: dict[str, bool] | None = None,
    only_solver_id: Collection[str] | str | None = None,
    **usr_kwargs,
) -> None:
    """Get proper data and plot it on an Axe.

    Parameters
    ----------
    axe :
        Object on which to add plot data.
    x_axis :
        Nature of x axis.
    y_axis :
        What to plot. ``"struct"`` draws the structure of the last
        accelerator instead of simulation data.
    colors :
        Holds the line labels from previous plots and associate it to their
        colors. Lines whose label is found here reuse that color.
    accelerators :
        Objects from which we take ``y_axis``.
    plot_section :
        To outline the different sections (of the first accelerator) in the
        background of the plots.
    symmetric_plot :
        If a symmetric plot (wrt x axis) should be added. The mirrored curve
        has the same color and style as the original one, but no label.
    get_kwargs :
        Keyword arguments for the :meth:`.SimulationOutput.get` method.
    only_solver_id :
        If set, we plot only data obtained with this solver(s). Must be
        :attr:`.BeamCalculator.id` (or, equivalently, a key(s) in
        :attr:`.Accelerator.simulation_outputs`). Typical values:
        ``"0_Envelope1D"`` or ``"1_TraceWin"``.
    usr_kwargs :
        User-defined ``kwargs``, passed to the |axplot| method. They override
        the keyword arguments returned by :func:`.all_accelerators_data`.

    """
    if plot_section:
        structure.outline_sections(accelerators[0].elts, axe, x_axis=x_axis)
    if y_axis == "struct":
        return structure.plot_structure(
            accelerators[-1].elts, axe, x_axis=x_axis
        )

    x_data, y_data, plt_kwargs = all_accelerators_data(
        x_axis,
        y_axis,
        *accelerators,
        error_presets=ERROR_PRESETS,
        error_reference=ERROR_REFERENCE,
        only_solver_id=only_solver_id,
        **(get_kwargs or {}),
    )

    # Alternate markers for the "cav" preset, so that overlaid cavity data
    # remains distinguishable.
    markers = ("o", "^")
    alternate_markers = y_axis in ("v_cav_mv", "phi_s")
    marker_index = 0
    for x, y, data_kwargs in zip(x_data, y_data, plt_kwargs):
        # Work on a copy: do not mutate the dicts given by the data getter.
        line_kwargs = dict(data_kwargs)
        if alternate_markers:
            line_kwargs["marker"] = markers[marker_index]
            marker_index = (marker_index + 1) % len(markers)

        label = line_kwargs.get("label")
        if colors is not None and label in colors:
            line_kwargs["color"] = colors[label]

        (line,) = axe.plot(x, y, **(line_kwargs | usr_kwargs))

        if symmetric_plot:
            # Same style as the main curve (user kwargs included), same
            # color, and no label so that it does not show in the legend.
            mirror_kwargs = (
                line_kwargs
                | usr_kwargs
                | {"color": line.get_color(), "label": None}
            )
            # np.negative also works when ``y`` is a plain sequence.
            axe.plot(x, np.negative(y), **mirror_kwargs)

    axe.grid(True)
    axe.set_ylabel(_y_label(y_axis))