Source code for lightwin.util.solvers

"""Define equation solvers.

For now, used only in acceptance computations.

"""

import logging
from typing import Callable

import numpy as np
from scipy.optimize import brentq


def solve_scalar_equation_brent(
    func: Callable[[float, float], float],
    param_value: float,
    x_bounds: tuple[float, float],
    warning: bool = True,
) -> float:
    """Solve ``func(x, param_value) = 0`` for ``x`` using Brent's method.

    Parameters
    ----------
    func
        Function ``f(x, param)`` whose root in ``x`` is to be found.
    param_value
        Single parameter value passed as second argument to ``func``.
    x_bounds
        Interval ``(x_min, x_max)`` in which to search for the root. The
        bounds are swapped if provided in reverse order.
    warning
        If True (default), a warning is logged when the bounds are inverted,
        when ``func`` has the same sign at both bounds, or when the root
        finder raises.

    Returns
    -------
    float
        Root found for the parameter value. NaN if ``func`` has the same sign
        at both bounds or if the root finder fails.

    """
    x_left, x_right = x_bounds
    if x_left > x_right:
        x_left, x_right = x_right, x_left
        if warning:
            logging.warning(
                f"The range ({x_bounds[0]}, {x_bounds[1]}) is inverted. "
                f"It has been corrected to ({x_left}, {x_right})."
            )

    # Evaluate each bound once: ``func`` may be expensive, and the values are
    # reused both for the sign test and for the warning message.
    f_left = func(x_left, param_value)
    f_right = func(x_right, param_value)

    # Compare signs rather than the raw product: the product of two tiny
    # same-sign values underflows to 0.0 and would slip past a ``> 0`` test.
    if np.sign(f_left) * np.sign(f_right) > 0:
        if warning:
            logging.warning(
                f"{f_left} and {f_right} have the same sign in solve_scalar_equation_brent(). "
                "There is no root in this range"
            )
        return np.nan

    try:
        # ``args`` makes brentq call ``func(x, param_value)``.
        return brentq(func, x_left, x_right, args=(param_value,))
    except (ValueError, RuntimeError) as e:
        if warning:
            logging.warning(
                f"Root finding failed in solve_scalar_equation_brent() with param={param_value}: {e}"
            )
        return np.nan