"""Define the functions to extract the data to plot.

.. todo::
    Fix the TransferMatrix plot with TraceWin solver.

"""
import itertools
import logging
from typing import Any, Callable
import numpy as np
from numpy.typing import NDArray
import lightwin.util.dicts_output as dic
from lightwin.beam_calculation.simulation_output.simulation_output import (
SimulationOutput,
)
from lightwin.core.accelerator.accelerator import Accelerator
from lightwin.util import helper
from lightwin.util.typing import GETTABLE_SIMULATION_OUTPUT_T
from lightwin.visualization.helper import (
X_AXIS_T,
)
[docs]
def all_accelerators_data(
    x_axis: X_AXIS_T,
    y_axis: GETTABLE_SIMULATION_OUTPUT_T,
    *accelerators: Accelerator,
    error_presets: dict[str, dict[str, Any]],
    error_reference: str,
    to_deg: bool = True,
    none_to_nan: bool = True,
    to_numpy: bool = True,
    warn_structure_dependent: bool = False,
    **get_kwargs,
) -> tuple[list[np.ndarray], list[np.ndarray], list[dict[str, Any]]]:
    """Get x_data, y_data, kwargs from all Accelerators (<=> for 1 subplot).

    Parameters
    ----------
    x_axis :
        Name of the quantity on the x-axis.
    y_axis :
        Name of the quantity on the y-axis. When it ends with ``"_err"``, the
        quantity without this suffix is fetched and an error curve is
        computed from it.
    *accelerators :
        Accelerators whose simulation outputs are gathered, in order.
    error_presets :
        Error settings, looked up with the full ``y_axis`` (suffix included).
        Only used for error plots.
    error_reference :
        Which simulation(s) serve as reference in the error computation. Only
        used for error plots.
    to_deg, none_to_nan, to_numpy, warn_structure_dependent, **get_kwargs :
        Forwarded to the per-accelerator data getter.

    Returns
    -------
    x_data, y_data, plt_kwargs
        One entry per curve to draw; labels are made unique.

    """
    error_plot = y_axis.endswith("_err")
    # For an error plot, fetch the underlying quantity first; the error is
    # computed afterwards from these raw data.
    key = y_axis.removesuffix("_err") if error_plot else y_axis

    x_data: list[np.ndarray] = []
    y_data: list[np.ndarray] = []
    plt_kwargs: list[dict[str, Any]] = []
    for accelerator in accelerators:
        x_dat, y_dat, plt_kw = _single_accelerator_all_simulations_data(
            x_axis,
            key,
            accelerator,
            to_deg=to_deg,
            none_to_nan=none_to_nan,
            to_numpy=to_numpy,
            warn_structure_dependent=warn_structure_dependent,
            **get_kwargs,
        )
        x_data.extend(x_dat)
        y_data.extend(y_dat)
        plt_kwargs.extend(plt_kw)

    if error_plot:
        # Presets are keyed by the full name, "_err" suffix included.
        fun_error = _error_calculation_function(
            y_axis, error_presets=error_presets
        )
        x_data, y_data, plt_kwargs = _compute_error(
            x_data,
            y_data,
            plt_kwargs,
            fun_error,
            error_reference=error_reference,
        )

    plt_kwargs = _avoid_similar_labels(plt_kwargs)
    return x_data, y_data, plt_kwargs
[docs]
def _single_accelerator_all_simulations_data(
    x_axis: X_AXIS_T,
    y_axis: GETTABLE_SIMULATION_OUTPUT_T,
    accelerator: Accelerator,
    **get_kwargs,
) -> tuple[list[np.ndarray], list[np.ndarray], list[dict[str, Any]]]:
    """Collect x data, y data and plot kwargs for every solver of an Accelerator.

    The label of each curve is the accelerator name followed by the part of
    the solver name preceding the first ``"("``, plus ``" (multipart)"`` for
    multiparticle simulations. The first simulation output is drawn with a
    solid line, every following one with a dashed line.

    Parameters
    ----------
    x_axis, y_axis :
        Names of the quantities to fetch.
    accelerator :
        Accelerator whose ``simulation_outputs`` are iterated.
    **get_kwargs :
        Forwarded to the single-simulation getter.

    Returns
    -------
    x_data, y_data, plt_kwargs
        One entry per simulation output of ``accelerator``.

    """
    all_x: list[np.ndarray] = []
    all_y: list[np.ndarray] = []
    all_kwargs: list[dict[str, Any]] = []

    outputs = accelerator.simulation_outputs.items()
    for index, (solver_name, sim_output) in enumerate(outputs):
        solver_label = solver_name.split("(")[0]
        if sim_output.is_multiparticle:
            solver_label = f"{solver_label} (multipart)"
        curve_label = f"{accelerator.name} {solver_label}"

        x_values, y_values, curve_kwargs = _single_simulation_all_data(
            x_axis, y_axis, sim_output, label=curve_label, **get_kwargs
        )
        curve_kwargs["label"] = curve_label
        # Only the very first curve is solid.
        curve_kwargs["ls"] = "-" if index == 0 else "--"

        all_x.append(x_values)
        all_y.append(y_values)
        all_kwargs.append(curve_kwargs)

    return all_x, all_y, all_kwargs
[docs]
def _single_simulation_all_data(
    x_axis: X_AXIS_T,
    y_axis: GETTABLE_SIMULATION_OUTPUT_T,
    simulation_output: SimulationOutput,
    label: str,
    **get_kwargs,
) -> tuple[NDArray[np.float64], NDArray[np.float64], dict[str, Any]]:
    """Get x data, y data, kwargs from a SimulationOutput.

    Parameters
    ----------
    x_axis, y_axis :
        Names of the quantities to fetch.
    simulation_output :
        Object holding the data.
    label :
        Curve label; only used in the log messages here.
    **get_kwargs :
        Forwarded to the single-data getter.

    Returns
    -------
    x_data, y_data, plt_kwargs
        ``plt_kwargs`` is a copy of the preset plot kwargs of ``y_axis``.
        When a quantity is missing, both data arrays are replaced by NaN
        dummies and ``plt_kwargs`` is empty. When the x and y shapes differ,
        ``y_data`` is replaced by NaNs of the shape of ``x_data`` and
        ``plt_kwargs`` is empty.

    """
    x_data = _single_simulation_data(x_axis, simulation_output, **get_kwargs)
    y_data = _single_simulation_data(y_axis, simulation_output, **get_kwargs)

    if x_data is None or y_data is None:
        if x_data is None:
            logging.error(
                f"{x_axis} not found in {label}. Setting it to dummy data. "
                f"Complete SimulationOutput is:\n{simulation_output}"
            )
        if y_data is None:
            logging.error(
                f"{y_axis} not found in {label}. Setting it to dummy data. "
                f"Complete SimulationOutput is:\n{simulation_output}"
            )
        x_data = np.full((10, 1), np.nan)
        y_data = np.full((10, 1), np.nan)
        return x_data, y_data, {}

    if (y_shape := y_data.shape) != (x_shape := x_data.shape):
        logging.error(
            f"Shape mismatch in {label}: {x_axis} has shape {x_shape} while "
            f"{y_axis} has shape {y_shape}. If this is a TransferMatrix plot "
            "with TraceWin solver, it is because TraceWin exports one transfer"
            " matrix per element while LightWin exports one per thin-lens "
            "(FIXME). Also happens with acceptance_phi and TraceWin. Skipping"
            f" this plot. Complete SimulationOutput is:\n{simulation_output}"
        )
        # Build a float array explicitly: np.full_like(x_data, np.nan) keeps
        # the dtype of x_data, and for an integer x axis NaN would be cast to
        # an arbitrary integer instead of a NaN.
        y_data = np.full(x_shape, np.nan, dtype=np.float64)
        return x_data, y_data, {}

    plt_kwargs = dic.plot_kwargs[y_axis].copy()
    return x_data, y_data, plt_kwargs
[docs]
def _single_simulation_data(
    axis: GETTABLE_SIMULATION_OUTPUT_T,
    simulation_output: SimulationOutput,
    to_deg: bool = True,
    **get_kwargs,
) -> NDArray[np.float64] | None:
    """Fetch one data array from a single SimulationOutput.

    Parameters
    ----------
    axis :
        Name of the quantity to fetch.
    simulation_output :
        Object whose ``get`` method is called.
    to_deg :
        Forwarded to ``simulation_output.get``, except that it is forced to
        False whenever ``axis`` contains ``"envelope_pos"``.
    **get_kwargs :
        Forwarded to ``simulation_output.get``.

    Returns
    -------
    NDArray[np.float64] | None
        Whatever ``simulation_output.get`` returns; callers in this module
        treat None as "quantity not found".

    """
    # Patch: envelopes must not be converted to degrees a second time.
    convert_to_deg = False if "envelope_pos" in axis else to_deg
    return simulation_output.get(axis, to_deg=convert_to_deg, **get_kwargs)
[docs]
def _avoid_similar_labels(plt_kwargs: list[dict]) -> list[dict]:
    """Make every ``"label"`` unique by suffixing doublons with a counter.

    The first occurrence of a label is kept as is. Each following doublon
    gets ``"_0"``, ``"_1"``, ... appended, skipping any candidate that is
    already in use. The dictionaries are modified in place.

    Parameters
    ----------
    plt_kwargs :
        Plot keyword arguments; every dict must have a ``"label"`` key.

    Returns
    -------
    list[dict]
        ``plt_kwargs`` itself (same list object), with unique labels.

    """
    seen: set[str] = set()
    for kwargs in plt_kwargs:
        label = kwargs["label"]
        if label in seen:
            # Increment the counter from the *original* label so that the
            # doublons of "a" become "a_0", "a_1", ... rather than being
            # stacked onto each other ("a_0_0").
            suffix = 0
            while f"{label}_{suffix}" in seen:
                suffix += 1
            label = f"{label}_{suffix}"
            kwargs["label"] = label
        seen.add(label)
    return plt_kwargs
# Error related
[docs]
def _error_calculation_function(
    y_axis: str,
    error_presets: dict[str, dict[str, Any]],
) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
    """Set the function called to compute error.

    Parameters
    ----------
    y_axis :
        Key of ``error_presets``.
    error_presets :
        Maps ``y_axis`` to a dict holding a ``"scale"`` factor and a
        ``"diff"`` method among ``"simple"``, ``"abs"``, ``"rel"`` and
        ``"log"``.

    Returns
    -------
    Callable[[np.ndarray, np.ndarray], np.ndarray]
        ``fun(y_ref, y_lin)`` returning the scaled error of ``y_lin`` with
        respect to ``y_ref``.

    Raises
    ------
    KeyError
        If ``y_axis`` is missing from ``error_presets``, if its preset lacks
        ``"scale"`` or ``"diff"``, or if ``"diff"`` is not a known method.

    """
    preset = error_presets[y_axis]
    scale = preset["scale"]
    error_computers: dict[
        str, Callable[[np.ndarray, np.ndarray], np.ndarray]
    ] = {
        "simple": lambda y_ref, y_lin: scale * (y_ref - y_lin),
        "abs": lambda y_ref, y_lin: scale * np.abs(y_ref - y_lin),
        "rel": lambda y_ref, y_lin: scale * (y_ref - y_lin) / y_ref,
        "log": lambda y_ref, y_lin: scale * np.log10(np.abs(y_lin / y_ref)),
    }
    diff = preset["diff"]
    try:
        return error_computers[diff]
    except KeyError as err:
        raise KeyError(
            f"Unknown error computation {diff!r} for {y_axis}. Allowed values"
            f" are {list(error_computers)}."
        ) from err
[docs]
def _compute_error(
    x_data: list[np.ndarray],
    y_data: list[np.ndarray],
    plt_kwargs: list[dict[str, Any]],
    fun_error: Callable[[np.ndarray, np.ndarray], np.ndarray],
    error_reference: str,
) -> tuple[list[np.ndarray], list[np.ndarray], list[dict[str, Any]]]:
    """Compute error with proper reference and proper function.

    Parameters
    ----------
    x_data, y_data :
        Data of every simulation, in the order they were gathered.
    plt_kwargs :
        Plot kwargs of every simulation.
    fun_error :
        Called as ``fun_error(y_ref, y_err)`` on the resampled data.
    error_reference :
        One of:

        - ``"ref accelerator (1st solv w/ 1st solv, 2nd w/ 2nd)"``: the first
          half of the simulations are references; the i-th simulation of the
          second half is compared with the i-th reference.
        - ``"ref accelerator (1st solver)"``: every other simulation is
          compared with simulation 0.
        - ``"ref accelerator (2nd solver)"``: every other simulation is
          compared with simulation 1; needs at least 4 simulations.

    Returns
    -------
    x_data_error, y_data_error, plt_kwargs
        One entry per non-reference simulation. Three empty lists are returned
        (and an error is logged) when the error cannot be computed, so that
        nothing gets plotted.

    """
    n_simulations = len(x_data)
    if error_reference == "ref accelerator (1st solv w/ 1st solv, 2nd w/ 2nd)":
        i_ref = list(range(n_simulations // 2))
    elif error_reference == "ref accelerator (1st solver)":
        i_ref = [0]
    elif error_reference == "ref accelerator (2nd solver)":
        if n_simulations < 4:
            logging.error(
                f"{error_reference = } not supported when only one "
                "simulation is performed."
            )
            return [], [], []
        i_ref = [1]
    else:
        logging.error(
            f"{error_reference = }, which is not allowed. Check "
            "allowed values in _compute_error."
        )
        return [], [], []

    i_err = [i for i in range(n_simulations) if i not in i_ref]
    if not i_ref or not i_err:
        # Without this guard, an empty i_ref makes i_ref[0] raise IndexError,
        # and an empty i_err compares the reference with itself.
        logging.error(
            f"Cannot compute error with {error_reference = }: only "
            f"{n_simulations} simulation(s), not enough to have both a "
            "reference and a simulation to compare with it."
        )
        return [], [], []

    # i_err is never shorter than i_ref here, so fillvalue only pads the
    # reference side: the extra simulations are compared with i_ref[0].
    indexes_ref_with_err = itertools.zip_longest(
        i_ref, i_err, fillvalue=i_ref[0]
    )
    x_data_error, y_data_error = [], []
    for ref, err in indexes_ref_with_err:
        x_interp, y_ref, _, y_err = helper.resample(
            x_data[ref], y_data[ref], x_data[err], y_data[err]
        )
        x_data_error.append(x_interp)
        y_data_error.append(fun_error(y_ref, y_err))
    return x_data_error, y_data_error, [plt_kwargs[i] for i in i_err]