"""Define helper function to ease lists manipulation in :mod:`.strategy`.
.. note::
If you are unsure about how a function works, check out the implementation
of the tests in :file:`LightWin/tests/test_failure/test_helper.py`.
"""
import itertools
import math
from collections.abc import Callable, Collection, Sequence
from functools import partial
from typing import Literal
TIE_POLITICS = ("upstream first", "downstream first")
TIE_POLITICS_T = Literal["upstream first", "downstream first"]
[docs]
def _distance_to_ref[T](
element: T,
failed: Sequence[T],
all_elements: Sequence[T],
tie_politics: TIE_POLITICS_T,
shift: int = 0,
) -> tuple[int, int]:
"""Give distance between ``element`` and closest of ``failed``.
Parameters
----------
element :
First object from which you want distance. Often, an :class:`.Element`
of a lattice that will potentially be used for compensation.
failed :
Second object or list of object from which you want distance. Often, a
list of failed :class:`.Element` or a list of lattices with a fault.
all_elements :
All the elements/lattices/sections.
tie_politics :
When two elements have the same position, will you want to have the
upstream or the downstream first?
shift :
Distance increase for downstream elements (``shift < 0``) or upstream
elements (``shift > 0``). Used to have a window of compensating
cavities which is not centered around the failed elements. The default
is 0.
Returns
-------
lowest_distance :
Index-distance between ``element`` and closest element of ``failed``.
Will be used as a primary sorting key.
index :
Index of ``element``. Will be used as a secondary index key, to sort
ties in distance.
"""
index = all_elements.index(element)
distances = (
abs(index - (failure_index := all_elements.index(failed_element)))
+ _penalty(index, failure_index, shift)
for failed_element in failed
)
lowest_distance = min(distances)
if tie_politics == "upstream first":
return lowest_distance, index
if tie_politics == "downstream first":
return lowest_distance, -index
raise OSError(f"{tie_politics = } not understood.")
[docs]
def _penalty(index: int, failure_index: int, shift: int) -> int:
"""Give the distance penalty.
.. note::
If ``shift > 0``, upstream elements are penalized.
If ``shift < 0``, downstream elements are penalized.
"""
if index == failure_index:
return 0
if (failure_index < index) is not (shift < 0):
return 0
return abs(shift)
[docs]
def sort_by_position[T](
all_elements: Sequence[T],
failed: Sequence[T],
tie_politics: TIE_POLITICS_T = "upstream first",
shift: int = 0,
) -> list[T]:
"""Sort given list by how far its elements are from ``elements[idx]``.
We go across every element in ``all_elements`` and get their index-distance
to the closest element of ``failed``. We sort ``all_elements`` by this
distance. When there is a tie, we put the more upstream or the more
downstream cavity first according to ``tie_politics``.
Parameters
----------
failed :
Second object or list of object from which you want distance. Often, a
list of failed :class:`.Element` or a list of lattices with a fault.
all_elements :
All the elements/lattices/sections.
tie_politics :
When two elements have the same position, will you want to have the
upstream or the downstream first?
shift :
Distance increase for downstream elements (``shift < 0``) or upstream
elements (``shift > 0``). Used to have a window of compensating
cavities which is not centered around the failed elements. Useful when
upstream cavities have more important power margins, or when you want
more downstream cavities because a full cryomodule is down.
"""
sorter = partial(
_distance_to_ref,
failed=failed,
all_elements=all_elements,
tie_politics=tie_politics,
shift=shift,
)
return sorted(all_elements, key=lambda element: sorter(element))
[docs]
def remove_lists_with_less_than_n_elements[T](
elements: Sequence[Sequence[T]], minimum_size: int = 1
) -> list[list[T]]:
"""Return a list where objects have a minimum length of ``minimum_size``."""
out = [list(x) for x in elements if len(x) >= minimum_size]
return out
[docs]
def gather[T](
failed_elements: list[T],
fun_sort: Callable[[Sequence[T] | Sequence[Sequence[T]]], list[T]],
) -> tuple[list[list[T]], list[list[T]]]:
"""Gather faults to be fixed together and associated compensating cav.
Parameters
----------
failed_elements :
Holds ungathered failed cavities.
fun_sort :
Takes in a list or a list of list of failed cavities, returns the list
or list of list of altered cavities (failed + compensating).
Returns
-------
failed_gathered :
Failures, gathered by faults that require the same compensating
cavities.
compensating_gathered :
Corresponding compensating cavities.
"""
r_comb = 2
flag_gathered = False
altered_gathered: list[list[T]] = []
failed_gathered = [[failed] for failed in failed_elements]
while not flag_gathered:
# List of list of corresp. compensating cavities
altered_gathered = [
fun_sort(failed_elements=failed) for failed in failed_gathered
]
# Set a counter to exit the 'for' loop when all faults are gathered
i = 0
n_combinations = len(altered_gathered)
if n_combinations <= 1:
flag_gathered = True
break
i_max = int(
math.factorial(n_combinations)
/ (
math.factorial(r_comb)
* math.factorial(n_combinations - r_comb)
)
)
# Now we look every list of required compensating cavities, and
# look for faults that require the same compensating cavities
for (idx1, altered1), (idx2, altered2) in itertools.combinations(
enumerate(altered_gathered), r_comb
):
i += 1
common = list(set(altered1) & set(altered2))
# If at least one cavity on common, gather the two
# corresponding fault and restart the whole process
if len(common) > 0:
failed_gathered[idx1].extend(failed_gathered.pop(idx2))
altered_gathered[idx1].extend(altered_gathered.pop(idx2))
break
# If we reached this point, it means that there is no list of
# faults that share compensating cavities.
if i == i_max:
flag_gathered = True
compensating_gathered = [
list(filter(lambda cavity: cavity not in failed_elements, sublist))
for sublist in altered_gathered
]
return failed_gathered, compensating_gathered
[docs]
def nested_containing_desired[T](
nested: Collection[Sequence[T]],
desired_elements: Collection[T],
) -> list[list[T]]:
"""Return collections of ``nested`` containing some ``desired_elements``.
Example
-------
``nested_containing_desired(ListOfElements.by_lattice, failed_elements)``
will return ``lattices_with_a_failure``
"""
nested_with_desired_elements = [
list(x) for x in nested if not set(desired_elements).isdisjoint(x)
]
return nested_with_desired_elements