Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: Add graceful handling of invalid initial brackets to elementwise bracket finding functions #20685

Merged
merged 9 commits into from
May 13, 2024
1 change: 1 addition & 0 deletions scipy/_lib/_elementwise_iterative_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_ECONVERR = -2
_EVALUEERR = -3
_ECALLBACK = -4
_EINPUTERR = -5
_ECONVERGED = 0
_EINPROGRESS = 1

Expand Down
49 changes: 32 additions & 17 deletions scipy/optimize/_bracket.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ def _bracket_root_iv(func, xl0, xr0, xmin, xmax, factor, args, maxiter):
if not maxiter == maxiter_int or maxiter < 0:
raise ValueError(message)

if not np.all((xmin <= xl0) & (xl0 < xr0) & (xr0 <= xmax)):
raise ValueError('`xmin <= xl0 < xr0 <= xmax` must be True (elementwise).')

return func, xl0, xr0, xmin, xmax, factor, args, maxiter


Expand Down Expand Up @@ -116,6 +113,7 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None,
- ``-2`` : The maximum number of iterations was reached.
- ``-3`` : A non-finite value was encountered.
- ``-4`` : Iteration was terminated by `callback`.
- ``-5`` : Initial bracket was invalid.
- ``1`` : The algorithm is proceeding normally (in `callback` only).
- ``2`` : A bracket was found in the opposite search direction (in `callback` only).

Expand Down Expand Up @@ -179,9 +177,19 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None,
# We don't need to retain the corresponding function value, since the
# fixed end of the bracket is only needed to compute the new value of the
# moving end; it is never returned.
xmin = np.broadcast_to(xmin, shape).astype(dtype, copy=False)
xmax = np.broadcast_to(xmax, shape).astype(dtype, copy=False)

xmin = np.broadcast_to(xmin, shape).astype(dtype, copy=False).ravel()
xmax = np.broadcast_to(xmax, shape).astype(dtype, copy=False).ravel()
status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress
# Stop with error code when initial bracket is invalid.
i = ~((xmin <= xl0) & (xl0 < xr0) & (xr0 <= xmax))
if i.ndim == 0:
i = np.array([i, i])
else:
i = np.concatenate((i, i))
mdhaber marked this conversation as resolved.
Show resolved Hide resolved

xmin, xmax, i = xmin.ravel(), xmax.ravel(), i.ravel()
status[i] = eim._EINPUTERR
limit = np.concatenate((xmin, xmax))

factor = np.broadcast_to(factor, shape).astype(dtype, copy=False).ravel()
Expand All @@ -205,7 +213,8 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None,
d[i] = x[i] - x0[i]
d[ni] = limit[ni] - x[ni]

status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress

# We'll terminate with error status when initial bracket is invalid.
mdhaber marked this conversation as resolved.
Show resolved Hide resolved
nit, nfev = 0, 1 # one function evaluation per side performed above

work = _RichResult(x=x, x0=x0, f=f, limit=limit, factor=factor,
Expand Down Expand Up @@ -246,11 +255,13 @@ def post_func_eval(x, f, work):

def check_termination(work):
stop = np.zeros_like(work.x, dtype=bool)
# Condition 0: initial bracket is invalid
stop[work.status == eim._EINPUTERR] = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A special case that I don't really want to deal with is xl0 == xr0 == true_root : ( Maybe we should just add a note about it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't addressed this yet. I'd rather not deal with that case either. It looks like the documentation currently doesn't even mention that xmin <= xl0 < xm0 < xr0 <= xmax is needed. I think just adding a note about that should be enough; the behavior will be exactly as documented.

Copy link
Contributor Author

@steppi steppi May 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just put things like this

 - ``-5`` : The initial bracket does not satisfy
              `xmin <= xl0 < xm0 < xr0 <= xmax`.

when documenting the error codes. Do you think that's sufficient, or should this go somewhere else as well? I was thinking of putting it in the Parameters documentation, but wasn't sure where to include it, since it pertains to every one of these bracket parameters. I guess it could go in the xm0 section, stating that the parameters are defined below; or in the xmin, xmax section.


# Condition 1: a valid bracket (or the root itself) has been found
sf = np.sign(work.f)
sf_last = np.sign(work.f_last)
i = (sf_last == -sf) | (sf_last == 0) | (sf == 0)
i = ((sf_last == -sf) | (sf_last == 0) | (sf == 0)) & ~stop
work.status[i] = eim._ECONVERGED
stop[i] = True

Expand Down Expand Up @@ -449,11 +460,6 @@ def _bracket_minimum_iv(func, xm0, xl0, xr0, xmin, xmax, factor, args, maxiter):
if not maxiter == maxiter_int or maxiter < 0:
raise ValueError(message)

if not np.all((xmin <= xl0) & (xl0 < xm0) & (xm0 < xr0) & (xr0 <= xmax)):
raise ValueError(
'`xmin <= xl0 < xm0 < xr0 <= xmax` must be True (elementwise).'
)

return func, xm0, xl0, xr0, xmin, xmax, factor, args, maxiter


Expand Down Expand Up @@ -524,6 +530,7 @@ def _bracket_minimum(func, xm0, *, xl0=None, xr0=None, xmin=None, xmax=None,
minimizer.
- ``-2`` : The maximum number of iterations was reached.
- ``-3`` : A non-finite value was encountered.
- ``-5`` : Initial bracket was invalid.
steppi marked this conversation as resolved.
Show resolved Hide resolved

success : bool
``True`` when the algorithm terminated successfully (status ``0``).
Expand Down Expand Up @@ -572,6 +579,11 @@ def _bracket_minimum(func, xm0, *, xl0=None, xr0=None, xmin=None, xmax=None,
# a read-only view.
factor = np.broadcast_to(factor, shape).astype(dtype, copy=True).ravel()

# Stop with error code when initial bracket is invalid.
status = np.full_like(xl0, eim._EINPROGRESS, dtype=int)
i = ~((xmin <= xl0) & (xl0 < xm0) & (xm0 < xr0) & (xr0 <= xmax))
status[i] = eim._EINPUTERR

# To simplify the logic, swap xl and xr if f(xl) < f(xr). We should always be
# marching downhill in the direction from xl to xr.
comp = fl0 < fr0
Expand All @@ -589,8 +601,6 @@ def _bracket_minimum(func, xm0, *, xl0=None, xr0=None, xmin=None, xmax=None,

# Step size is divided by factor for case where there is a limit.
factor[limited] = 1 / factor[limited]

status = np.full_like(xl0, eim._EINPROGRESS, dtype=int)
nit, nfev = 0, 3

work = _RichResult(xl=xl0, xm=xm0, xr=xr0, xr0=xr0, fl=fl0, fm=fm0, fr=fr0,
Expand Down Expand Up @@ -622,12 +632,17 @@ def post_func_eval(x, f, work):
work.fl, work.fm, work.fr = work.fm, work.fr, f

def check_termination(work):
stop = np.zeros_like(work.xr, dtype=bool)
# Condition 0: Initial bracket is invalid.
stop[work.status == eim._EINPUTERR] = True
mdhaber marked this conversation as resolved.
Show resolved Hide resolved

# Condition 1: A valid bracket has been found.
stop = (
i = (
(work.fl >= work.fm) & (work.fr > work.fm)
| (work.fl > work.fm) & (work.fr >= work.fm)
)
work.status[stop] = eim._ECONVERGED
) & ~stop
work.status[i] = eim._ECONVERGED
stop[i] = True

# Condition 2: Moving end of bracket reaches limit.
i = (work.xr == work.limit) & ~stop
Expand Down
68 changes: 34 additions & 34 deletions scipy/optimize/tests/test_bracket.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,20 +155,29 @@ def f(xs, js):
funcs = [lambda x: x - 1.5,
lambda x: x - 1000,
lambda x: x - 1000,
lambda x: np.nan]
lambda x: np.nan,
lambda x: x,
lambda x: x,
lambda x: x]
Copy link
Contributor

@mdhaber mdhaber May 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we only need to have one example of each status code here; I'd prefer that rather than having one example of each of the first few codes and many examples of the last code. My intent when writing this test was to show that the codes could all be different; it really is elementwise. There could be a separate function to test the logic for generating this status code with many different cases, if desired (i.e. all possible invalid orderings). But maybe the most natural example here would be to have everything in reverse order, xmin > xl0 > xr0 > xmax.


return [funcs[j](x) for x, j in zip(xs, js)]

args = (np.arange(4, dtype=np.int64),)
res = _bracket_root(f, xl0=[-1, -1, -1, -1], xr0=[1, 1, 1, 1],
xmin=[-np.inf, -1, -np.inf, -np.inf],
xmax=[np.inf, 1, np.inf, np.inf],
args = (np.arange(7, dtype=np.int64),)
res = _bracket_root(f,
xl0=[-1, -1, -1, -1, 4, -4, -4],
xr0=[1, 1, 1, 1, -4, 4, 4],
xmin=[-np.inf, -1, -np.inf, -np.inf, -np.inf, -np.inf, 10],
xmax=[np.inf, 1, np.inf, np.inf, np.inf, np.nan, np.inf],
args=args, maxiter=3)

ref_flags = np.array([eim._ECONVERGED,
_ELIMITS,
eim._ECONVERR,
eim._EVALUEERR])
eim._EVALUEERR,
eim._EINPUTERR,
eim._EINPUTERR,
eim._EINPUTERR])

assert_equal(res.status, ref_flags)

@pytest.mark.parametrize("root", (0.622, [0.622, 0.623]))
Expand Down Expand Up @@ -213,14 +222,6 @@ def test_input_validation(self):
with pytest.raises(ValueError, match=message):
_bracket_root(lambda x: x, -4, 4, factor=0.5)

message = '`xmin <= xl0 < xr0 <= xmax` must be True'
with pytest.raises(ValueError, match=message):
_bracket_root(lambda x: x, 4, -4)
with pytest.raises(ValueError, match=message):
_bracket_root(lambda x: x, -4, 4, xmax=np.nan)
with pytest.raises(ValueError, match=message):
_bracket_root(lambda x: x, -4, 4, xmin=10)

message = "shape mismatch: objects cannot be broadcast"
# raised by `np.broadcast`, but the traceback is readable IMO
with pytest.raises(ValueError, match=message):
Expand Down Expand Up @@ -413,17 +414,30 @@ def f(xs, js):
funcs = [lambda x: (x - 1.5)**2,
lambda x: x,
lambda x: x,
lambda x: np.nan]
lambda x: np.nan,
lambda x: x**2,
lambda x: x**2,
lambda x: x**2,
lambda x: x**2,
lambda x: x**2,
lambda x: x**2]

return [funcs[j](x) for x, j in zip(xs, js)]

args = (np.arange(4, dtype=np.int64),)
xl0, xm0, xr0 = np.full(4, -1.0), np.full(4, 0.0), np.full(4, 1.0)
result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0,
xmin=[-np.inf, -1.0, -np.inf, -np.inf],
args = (np.arange(10, dtype=np.int64),)
xl0 = [-1.0, -1.0, -1.0, -1.0, 6.0, -3.0, -3.0, -6.0, -np.nan, 10.0]
xm0 = [0.0, 0.0, 0.0, 0.0, 4.0, -4.0, -4.0, -4.0, -4.0, -4.0]
xr0 = [1.0, 1.0, 1.0, 1.0, 7.0, -6.0, -2.0, -5.0, 4.0, np.nan]
xmin=[-np.inf, -1.0, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf,
-np.inf, -np.inf, -np.inf]
result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, xmin=xmin,
args=args, maxiter=3)

reference_flags = np.array([eim._ECONVERGED, _ELIMITS,
eim._ECONVERR, eim._EVALUEERR])
eim._ECONVERR, eim._EVALUEERR,
eim._EINPUTERR, eim._EINPUTERR,
eim._EINPUTERR, eim._EINPUTERR,
eim._EINPUTERR, eim._EINPUTERR])
assert_equal(result.status, reference_flags)

@pytest.mark.parametrize("minimum", (0.622, [0.622, 0.623]))
Expand Down Expand Up @@ -469,20 +483,6 @@ def test_input_validation(self):
with pytest.raises(ValueError, match=message):
_bracket_minimum(lambda x: x, -4, factor=0.5)

message = '`xmin <= xl0 < xm0 < xr0 <= xmax` must be True'
with pytest.raises(ValueError, match=message):
_bracket_minimum(lambda x: x**2, 4, xl0=6)
with pytest.raises(ValueError, match=message):
_bracket_minimum(lambda x: x**2, -4, xr0=-6)
with pytest.raises(ValueError, match=message):
_bracket_minimum(lambda x: x**2, -4, xl0=-3, xr0=-2)
with pytest.raises(ValueError, match=message):
_bracket_minimum(lambda x: x**2, -4, xl0=-6, xr0=-5)
with pytest.raises(ValueError, match=message):
_bracket_minimum(lambda x: x**2, -4, xl0=-np.nan)
with pytest.raises(ValueError, match=message):
_bracket_minimum(lambda x: x**2, -4, xr0=np.nan)

message = "shape mismatch: objects cannot be broadcast"
# raised by `np.broadcast`, but the traceback is readable IMO
with pytest.raises(ValueError, match=message):
Expand Down