Source code for lightwin.visualization.helper

"""Define types and helpers for the visualization library."""

import logging
from collections.abc import Sequence
from pathlib import Path
from typing import Literal

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure

# Type alias for the allowed x-axis identifiers (judging by its name, the
# quantity used as abscissa in plots). Presumably "z_abs" is an absolute
# longitudinal position and "elt_idx" an element index — TODO confirm
# against the plotting functions that use this alias.
X_AXIS_T = Literal["z_abs", "elt_idx"]


def create_fig_if_not_exists(
    axnum: int | list[int] | range,
    sharex: bool = False,
    num: int = 1,
    clean_fig: bool = False,
    **kwargs,
) -> tuple[Figure, list[Axes]]:
    """Return figure ``num`` and its axes, creating them if needed.

    If a figure with number ``num`` already exists, it is reused as-is: its
    current axes are returned and ``axnum`` / ``sharex`` are ignored.
    Otherwise, a new figure is created and populated with the requested axes.

    Parameters
    ----------
    axnum : int | list[int] | range
        Either the number of desired axes, stacked in a single column, or an
        iterable of subplot positions as understood by
        ``Figure.add_subplot`` (e.g. ``[211, 212]``).
    sharex : bool, optional
        If the x axis of every created axes should be shared with the first
        one. Only used when a new figure is created. The default is False.
    num : int, optional
        Figure number. The default is 1.
    clean_fig : bool, optional
        If the axes of an already existing figure should be cleared. The
        default is False.
    **kwargs
        Currently unused.

    Returns
    -------
    fig : Figure
        The existing or newly created figure.
    axlist : list[Axes]
        Axes of ``fig``. Empty when a new figure is created but no axes
        were requested.

    """
    if plt.fignum_exists(num):
        fig = plt.figure(num)
        axlist = fig.get_axes()
        if clean_fig:
            clean_figures([num])
        return fig, axlist

    if isinstance(axnum, int):
        # One column, ``axnum`` rows. Use the (nrows, ncols, index) form:
        # the three-digit shorthand (e.g. 311) only supports up to nine
        # subplots and fails for ``axnum >= 10``.
        subplot_args = [(axnum, 1, index) for index in range(1, axnum + 1)]
    else:
        subplot_args = [(position,) for position in axnum]

    fig = plt.figure(num)
    if not subplot_args:
        # Nothing to add; avoid the IndexError the first-axes lookup would raise.
        return fig, []

    first_ax = fig.add_subplot(*subplot_args[0])
    shared_ax = first_ax if sharex else None
    axlist = [first_ax]
    axlist += [
        fig.add_subplot(*args, sharex=shared_ax) for args in subplot_args[1:]
    ]
    return fig, axlist
def clean_figures(fig_ids: Sequence[int | str | Figure | SubFigure]) -> None:
    """Clear all the axes of the given figures.

    Parameters
    ----------
    fig_ids : Sequence[int | str | Figure | SubFigure]
        Identifiers of the figures to clear; each one is passed to
        ``plt.figure`` to retrieve the corresponding figure.

    """
    figures = (plt.figure(identifier) for identifier in fig_ids)
    for figure in figures:
        clean_axes(figure.get_axes())
def clean_axes(ax_ids: Sequence[Axes]) -> None:
    """Clear each of the given axes.

    Parameters
    ----------
    ax_ids : Sequence[Axes]
        Axes to clear, each through its ``cla`` method.

    """
    for current_ax in ax_ids:
        current_ax.cla()
def remove_artists(axe: Axes) -> None:
    """Remove all lines from ``axe``, but keep labels and grids.

    The property cycle is also reset with ``set_prop_cycle(None)`` so that
    subsequently plotted lines restart from the beginning of the cycle.

    Parameters
    ----------
    axe : Axes
        Axes whose lines are removed.

    """
    # Iterate over a snapshot: ``artist.remove()`` mutates ``axe.lines``, and
    # removing while iterating the live container skips every other line.
    for artist in list(axe.lines):
        artist.remove()
    axe.set_prop_cycle(None)  # type: ignore
def savefig(
    fig: Figure,
    filepath: Path,
    *,
    figsize: tuple[float, float] = (25.6, 13.64),
) -> None:
    """Resize the figure, tighten its layout and save it.

    Parameters
    ----------
    fig : Figure
        Figure to save.
    filepath : Path
        Destination file, passed to ``Figure.savefig``.
    figsize : tuple[float, float], optional
        Width and height in inches applied before saving. The default,
        ``(25.6, 13.64)``, is the size historically hard-coded here.

    """
    fig.set_size_inches(*figsize)
    fig.tight_layout()
    fig.savefig(filepath)
    # Lazy %-formatting: the message is only built if DEBUG is enabled.
    logging.debug("Fig. saved in %s", filepath)