"""Define a class to easily generate the :class:`.SimulationOutput`."""
import logging
from abc import ABCMeta
from dataclasses import dataclass
from functools import partial
from pathlib import Path
import numpy as np
import lightwin.util.converters as convert
from lightwin.beam_calculation.simulation_output.factory import (
SimulationOutputFactory,
)
from lightwin.beam_calculation.simulation_output.simulation_output import (
SimulationOutput,
)
from lightwin.beam_calculation.tracewin.beam_parameters_factory import (
BeamParametersFactoryTraceWin,
)
from lightwin.beam_calculation.tracewin.element_tracewin_parameters_factory import (
ElementTraceWinParametersFactory,
)
from lightwin.beam_calculation.tracewin.transfer_matrix_factory import (
TransferMatrixFactoryTraceWin,
)
from lightwin.constants import c
from lightwin.core.list_of_elements.list_of_elements import ListOfElements
from lightwin.core.particle import ParticleFullTrajectory, ParticleInitialState
from lightwin.failures.set_of_cavity_settings import SetOfCavitySettings
@dataclass
class SimulationOutputFactoryTraceWin(SimulationOutputFactory):
    """A class for creating simulation outputs for :class:`.TraceWin`."""

    # Forwarded as ``out_folder`` to every created :class:`.SimulationOutput`.
    out_folder: Path
    # Name of the TraceWin results file, looked up inside ``path_cal``.
    _filename: Path
    # Creates the per-element objects stored in ``elt.beam_calc_param``.
    beam_calc_parameters_factory: ElementTraceWinParametersFactory

    def __post_init__(self) -> None:
        """Set filepath-related attributes and create factories.

        ``self.load_results`` is :func:`_load_results_generic` with its
        ``filename`` argument fixed to ``self._filename``.

        The created factories are :class:`.TransferMatrixFactory` and
        :class:`.BeamParametersFactory`. The sub-class that is used is declared
        in :meth:`._transfer_matrix_factory_class` and
        :meth:`._beam_parameters_factory_class`.

        """
        self.load_results = partial(
            _load_results_generic, filename=self._filename
        )
        # Factories created in ABC's __post_init__
        return super().__post_init__()

    @property
    def _transfer_matrix_factory_class(self) -> ABCMeta:
        """Give the **class** of the transfer matrix factory."""
        return TransferMatrixFactoryTraceWin

    @property
    def _beam_parameters_factory_class(self) -> ABCMeta:
        """Give the **class** of the beam parameters factory."""
        return BeamParametersFactoryTraceWin

    def run(
        self,
        elts: ListOfElements,
        path_cal: Path,
        exception: bool,
        set_of_cavity_settings: SetOfCavitySettings,
    ) -> SimulationOutput:
        """Create an object holding all relatable simulation results.

        Parameters
        ----------
        elts : ListOfElements
            Contains all the elements, or only a fraction of them.
        path_cal : pathlib.Path
            Path to results folder.
        exception : bool
            Indicates if the run was unsuccessful or not. When True, the
            results file is repaired before being loaded (incomplete line
            removed, missing elements padded with dummy data), and null
            beam sizes are replaced by NaN.
        set_of_cavity_settings : SetOfCavitySettings
            Cavity settings, passed unchanged to the created
            :class:`.SimulationOutput`.

        Returns
        -------
        simulation_output : SimulationOutput
            Holds all relatable data in a consistent way between the different
            :class:`.BeamCalculator` objects.

        """
        if exception:
            # The results file of a failed run may be truncated: repair it
            # in place before loading it.
            filepath = Path(path_cal, self._filename)
            _remove_incomplete_line(filepath)
            _add_dummy_data(filepath, elts)

        results = self._create_main_results_dictionary(
            path_cal, elts.input_particle
        )
        if exception:
            results = _remove_invalid_values(results)

        z_abs = results["z(m)"]
        self._save_tracewin_meshing_in_elements(elts, results["##"], z_abs)

        synch_trajectory = ParticleFullTrajectory(
            w_kin=results["w_kin"],
            phi_abs=results["phi_abs"],
            synchronous=True,
            beam=self._beam_kwargs,
        )

        cavity_parameters = self._create_cavity_parameters(path_cal, len(elts))

        element_to_index = self._generate_element_to_index_func(elts)
        transfer_matrix = self.transfer_matrix_factory.run(
            elts.tm_cumul_in, path_cal, element_to_index
        )

        gamma_kin = synch_trajectory.get("gamma")
        beam_parameters = self.beam_parameters_factory.factory_method(
            z_abs, gamma_kin, results, element_to_index
        )

        simulation_output = SimulationOutput(
            out_folder=self.out_folder,
            is_multiparticle=hasattr(beam_parameters, "phiw99"),
            is_3d=True,
            z_abs=z_abs,
            synch_trajectory=synch_trajectory,
            cav_params=cavity_parameters,
            beam_parameters=beam_parameters,
            element_to_index=element_to_index,
            transfer_matrix=transfer_matrix,
            set_of_cavity_settings=set_of_cavity_settings,
        )
        # NOTE(review): ``z_abs`` was already given to the constructor; this
        # re-assignment looks redundant unless SimulationOutput alters it
        # on init -- confirm before removing.
        simulation_output.z_abs = z_abs
        # FIXME attribute was not declared
        simulation_output.pow_lost = results["Powlost"]
        return simulation_output

    def _create_main_results_dictionary(
        self, path_cal: Path, input_particle: ParticleInitialState
    ) -> dict[str, np.ndarray]:
        """Load the TraceWin results, compute common interest quantities.

        The raw columns are loaded with ``self.load_results``. Energy- and
        phase-related entries are then added by
        ``_set_energy_related_results`` and ``_set_phase_related_results``.
        The phase-related entries use the entry position and phase of
        ``input_particle``.

        """
        results = self.load_results(path_cal=path_cal)
        results = _set_energy_related_results(results, **self._beam_kwargs)
        results = _set_phase_related_results(
            results, z_in=input_particle.z_in, phi_in=input_particle.phi_abs
        )
        return results

    # TODO FIXME
    def _save_tracewin_meshing_in_elements(
        self, elts: ListOfElements, elt_numbers: np.ndarray, z_abs: np.ndarray
    ) -> None:
        """Store in every element where TraceWin evaluated ``w_kin``...

        The elements are numbered from 1. For element number ``n``, the
        result rows whose element number (from ``elt_numbers``) equals ``n``
        are located:

        - ``s_out`` is the last of these rows.
        - ``s_in`` is the row just before the first of them.

        The matching slice of ``z_abs`` and both indexes are passed to
        ``beam_calc_parameters_factory``. The output is stored in
        ``elt.beam_calc_param`` under this solver's id.

        Raises
        ------
        IndexError
            If an element has no row in the results.

        """
        elt_numbers = elt_numbers.astype(int)
        for elt_number, elt in enumerate(elts, start=1):
            elt_mesh_indexes = np.where(elt_numbers == elt_number)[0]
            if elt_mesh_indexes.size == 0:
                # Same exception type as the bare ``[0]`` lookup would raise,
                # but with a message pointing at the culprit element.
                raise IndexError(
                    f"Element number {elt_number} has no mesh point in the "
                    "TraceWin results."
                )
            # Row preceding this element's first row; presumably the element
            # entry (exit of the previous one) -- TODO confirm.
            s_in = elt_mesh_indexes[0] - 1
            s_out = elt_mesh_indexes[-1]
            z_element = z_abs[s_in : s_out + 1]
            elt.beam_calc_param[self._solver_id] = (
                self.beam_calc_parameters_factory.run(
                    elt, z_element, s_in, s_out
                )
            )

    def _create_cavity_parameters(
        self,
        path_cal: Path,
        n_elts: int,
        filename: Path = Path("Cav_set_point_res.dat"),
    ) -> dict[str, list[float | None]]:
        """Load and format a dict containing v_cav and phi_s.

        It has the same format as :class:`.Envelope1D` solver format.

        Parameters
        ----------
        path_cal : pathlib.Path
            Path to the folder where the cavity parameters file is stored.
        n_elts : int
            Number of elements under study.
        filename : pathlib.Path, optional
            The name of the cavity parameters file produced by TraceWin. The
            default is Path('Cav_set_point_res.dat').

        Returns
        -------
        cavity_param : dict[str, list[float | None]]
            Contains the cavity parameters. Keys are ``'v_cav_mv'`` and
            ``'phi_s'``.

        """
        cavity_parameters = _load_cavity_parameters(path_cal, filename)
        cavity_parameters = _cavity_parameters_uniform_with_envelope1d(
            cavity_parameters, n_elts
        )
        return cavity_parameters
# =============================================================================
# Main `results` dictionary
# =============================================================================
[docs]
def _0_to_NaN(data: np.ndarray) -> np.ndarray:
"""Replace 0 by np.nan in given array."""
data[np.where(data == 0.0)] = np.nan
return data
def _remove_invalid_values(
    results: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Replace the null beam sizes of a failed run by NaN.

    The ``"SizeX"``, ``"SizeY"`` and ``"SizeZ"`` arrays, in this order, go
    through :func:`_0_to_NaN`. ``results`` is updated in place and returned.
    """
    for size_key in ("SizeX", "SizeY", "SizeZ"):
        results[size_key] = _0_to_NaN(results[size_key])
    return results
[docs]
def _load_results_generic(
filename: Path, path_cal: Path
) -> dict[str, np.ndarray]:
"""Load the TraceWin results.
This function is not called directly. Instead, every instance of
:class:`.TraceWin` object has a `load_results` method which calls this
function with a default ``filename`` argument.
The value of ``filename`` depends on the TraceWin simulation that was run:
multiparticle or envelope.
Parameters
----------
filename : pathlib.Path
Results file produced by TraceWin.
path_cal : pathlib.Path
Folder where the results file is located.
Returns
-------
results : dict[str, numpy.ndarray]
Dictionary containing the raw outputs from TraceWin.
"""
f_p = Path(path_cal, filename)
n_lines_header = 9
results = {}
with open(f_p, encoding="utf-8") as file:
for i, line in enumerate(file):
if i == 1:
__mc2, freq, __z, __i, __npart = line.strip().split()
if i == n_lines_header:
headers = line.strip().split()
break
results["freq"] = float(freq)
out = np.loadtxt(f_p, skiprows=n_lines_header)
for i, key in enumerate(headers):
results[key] = out[:, i]
logging.debug(f"successfully loaded {f_p}")
return results
# =============================================================================
# Handle errors
# =============================================================================
[docs]
def _remove_incomplete_line(filepath: Path) -> None:
"""
Remove incomplete line from ``.out`` file.
.. todo::
fix possible unbound error for ``n_columns``.
"""
n_lines_header = 9
i_last_valid = -1
with open(filepath, encoding="utf-8") as file:
lines = file.readlines()
for i, line in enumerate(lines):
if i < n_lines_header:
continue
if i == n_lines_header:
n_columns = len(line.split())
if len(line.split()) != n_columns:
i_last_valid = i
break
if i_last_valid == -1:
return
logging.warning(
f"Not enough columns in `.out` after line {i_last_valid}. "
"Removing all lines after this one..."
)
with open(filepath, "w", encoding="utf-8") as file:
for i, line in enumerate(lines):
if i >= i_last_valid:
break
file.write(line)
[docs]
def _add_dummy_data(filepath: Path, elts: ListOfElements) -> None:
"""
Add dummy data at the end of the ``.out`` to reach end of linac.
We also round the column 'z', to avoid a too big mismatch between the z
column and what we should have.
.. todo::
another possibly unbound error to handle
"""
with open(filepath, "r+", encoding="utf-8") as file:
for line in file:
pass
last_idx_in_file = int(line.split()[0])
last_element_in_file = elts[last_idx_in_file - 1]
if last_element_in_file is not elts[-1]:
logging.warning(
"Incomplete `.out` file. Trying to complete with "
"dummy data..."
)
elts_to_add = elts[last_idx_in_file:]
last_pos = np.round(float(line.split()[1]), 4)
for i, elt in enumerate(elts_to_add, start=last_idx_in_file + 1):
last_pos += elt.get("length_m", to_numpy=False)
new_line = line.split()
new_line[0] = str(i)
new_line[1] = str(last_pos)
new_line = " ".join(new_line) + "\n"
file.write(new_line)
# =============================================================================
# Cavity parameters
# =============================================================================
[docs]
def _load_cavity_parameters(
path_cal: Path, filename: Path
) -> dict[str, np.ndarray]:
"""
Get the cavity parameters calculated by TraceWin.
Parameters
----------
path_cal : pathlib.Path
Path to the folder where the cavity parameters file is stored.
filename : pathlib.Path
The name of the cavity parameters file produced by TraceWin.
Returns
-------
cavity_param : dict[float, numpy.ndarray]
Contains the cavity parameters.
"""
f_p = Path(path_cal, filename)
n_lines_header = 1
with open(f_p, encoding="utf-8") as file:
for i, line in enumerate(file):
if i == n_lines_header - 1:
headers = line.strip().split()
break
out = np.loadtxt(f_p, skiprows=n_lines_header)
cavity_parameters = {key: out[:, i] for i, key in enumerate(headers)}
logging.debug(f"successfully loaded {f_p}")
return cavity_parameters