Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: optimize._chandrupatla: add array API support #20689

Merged
merged 13 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 25 additions & 3 deletions scipy/_lib/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def is_complex(x: Array, xp: ModuleType) -> bool:
def xp_minimum(x1, x2):
# `xp` is not accepted as a parameter here: the namespace is inferred from `x1` and `x2`.
xp = array_namespace(x1, x2)
if hasattr(xp, 'minimum'):
return xp.minimum(x1, x2)
x1, x2 = xp.broadcast_arrays(x1, x2)
dtype = xp.result_type(x1.dtype, x2.dtype)
res = xp.asarray(x1, copy=True, dtype=dtype)
Expand All @@ -404,10 +406,16 @@ def xp_minimum(x1, x2):
# temporary substitute for xp.clip, which is not yet in all backends
# or covered by array_api_compat.
def xp_clip(x, a, b, xp=None):
xp = array_namespace(xp) if xp is None else xp
xp = array_namespace(x) if xp is None else xp
a, b = xp.asarray(a, dtype=x.dtype), xp.asarray(b, dtype=x.dtype)
if hasattr(xp, 'clip'):
return xp.clip(x, a, b)
x, a, b = xp.broadcast_arrays(x, a, b)
y = xp.asarray(x, copy=True)
y[y < a] = a
y[y > b] = b
ia = y < a
y[ia] = a[ia]
ib = y > b
y[ib] = b[ib]
return y[()] if y.ndim == 0 else y


Expand All @@ -428,3 +436,17 @@ def xp_copysign(x1, x2, xp=None):
xp = array_namespace(x1, x2) if xp is None else xp
abs_x1 = xp.abs(x1)
return xp.where(x2 >= 0, abs_x1, -abs_x1)


# Partial substitute for xp.sign. Some backends do not propagate NaN through
# `sign` the way NumPy does, so that special case is handled explicitly here.
# (https://github.com/data-apis/array-api-compat/issues/136)
def xp_sign(x, xp=None):
xp = array_namespace(x) if xp is None else xp
if is_numpy(xp): # only NumPy implements the special cases correctly
return xp.sign(x)
sign = xp.full_like(x, xp.nan)
tupui marked this conversation as resolved.
Show resolved Hide resolved
one = xp.asarray(1, dtype=x.dtype)
sign = xp.where(x > 0, one, sign)
sign = xp.where(x < 0, -one, sign)
sign = xp.where(x == 0, 0*one, sign)
return sign
133 changes: 80 additions & 53 deletions scipy/_lib/_elementwise_iterative_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# `scipy.optimize._bracket._bracket_minimize for finding minimization brackets,
# `scipy.integrate._tanhsinh._tanhsinh` for numerical quadrature.

import math
import numpy as np
from ._util import _RichResult, _call_callback_maybe_halt
from ._array_api import array_namespace, size as xp_size

_ESIGNERR = -1
_ECONVERR = -2
Expand Down Expand Up @@ -70,18 +72,19 @@ def _initialize(func, xs, args, complex_ok=False, preserve_shape=None):
`scipy.optimize._chandrupatla`.
"""
nx = len(xs)
xp = array_namespace(*xs)

# Try to preserve `dtype`, but we need to ensure that the arguments are at
# least floats before passing them into the function; integers can overflow
# and cause failure.
# There might be benefit to combining the `xs` into a single array and
# calling `func` once on the combined array. For now, keep them separate.
xas = np.broadcast_arrays(*xs, *args) # broadcast and rename
xat = np.result_type(*[xa.dtype for xa in xas])
xat = np.float64 if np.issubdtype(xat, np.integer) else xat
xas = xp.broadcast_arrays(*xs, *args) # broadcast and rename
xat = xp.result_type(*[xa.dtype for xa in xas])
xat = xp.asarray(1.).dtype if xp.isdtype(xat, "integral") else xat
xs, args = xas[:nx], xas[nx:]
xs = [x.astype(xat, copy=False)[()] for x in xs]
fs = [np.asarray(func(x, *args)) for x in xs]
xs = [xp.asarray(x, dtype=xat) for x in xs] # use copy=False when implemented
fs = [xp.asarray(func(x, *args)) for x in xs]
shape = xs[0].shape
fshape = fs[0].shape

Expand All @@ -90,38 +93,38 @@ def _initialize(func, xs, args, complex_ok=False, preserve_shape=None):
def func(x, *args, shape=shape, func=func, **kwargs):
i = (0,)*(len(fshape) - len(shape))
return func(x[i], *args, **kwargs)
shape = np.broadcast_shapes(fshape, shape)
xs = [np.broadcast_to(x, shape) for x in xs]
args = [np.broadcast_to(arg, shape) for arg in args]
shape = np.broadcast_shapes(fshape, shape) # just shapes; use of NumPy OK
xs = [xp.broadcast_to(x, shape) for x in xs]
args = [xp.broadcast_to(arg, shape) for arg in args]

message = ("The shape of the array returned by `func` must be the same as "
"the broadcasted shape of `x` and all other `args`.")
if preserve_shape is not None: # only in tanhsinh for now
message = f"When `preserve_shape=False`, {message.lower()}"
shapes_equal = [f.shape == shape for f in fs]
if not np.all(shapes_equal):
if not all(shapes_equal): # use Python all to reduce overhead
raise ValueError(message)

# These algorithms tend to mix the dtypes of the abscissae and function
# values, so figure out what the result will be and convert them all to
# that type from the outset.
xfat = np.result_type(*([f.dtype for f in fs] + [xat]))
if not complex_ok and not np.issubdtype(xfat, np.floating):
xfat = xp.result_type(*([f.dtype for f in fs] + [xat]))
if not complex_ok and not xp.isdtype(xfat, "real floating"):
raise ValueError("Abscissae and function output must be real numbers.")
xs = [x.astype(xfat, copy=True)[()] for x in xs]
fs = [f.astype(xfat, copy=True)[()] for f in fs]
xs = [xp.asarray(x, dtype=xfat, copy=True) for x in xs]
fs = [xp.asarray(f, dtype=xfat, copy=True) for f in fs]

# To ensure that we can do indexing, we'll work with at least 1d arrays,
# but remember the appropriate shape of the output.
xs = [x.ravel() for x in xs]
fs = [f.ravel() for f in fs]
args = [arg.flatten() for arg in args]
return func, xs, fs, args, shape, xfat
xs = [xp.reshape(x, (-1,)) for x in xs]
fs = [xp.reshape(f, (-1,)) for f in fs]
args = [xp.reshape(xp.asarray(arg, copy=True), (-1,)) for arg in args]
tupui marked this conversation as resolved.
Show resolved Hide resolved
return func, xs, fs, args, shape, xfat, xp


def _loop(work, callback, shape, maxiter, func, args, dtype, pre_func_eval,
post_func_eval, check_termination, post_termination_check,
customize_result, res_work_pairs, preserve_shape=False):
customize_result, res_work_pairs, xp, preserve_shape=False):
"""Main loop of a vectorized scalar optimization algorithm

Parameters
Expand Down Expand Up @@ -186,82 +189,88 @@ def _loop(work, callback, shape, maxiter, func, args, dtype, pre_func_eval,
computation on elements that have already converged.

"""
if xp is None:
raise NotImplementedError("Must provide xp.")

cb_terminate = False

# Initialize the result object and active element index array
n_elements = int(np.prod(shape))
active = np.arange(n_elements) # in-progress element indices
res_dict = {i: np.zeros(n_elements, dtype=dtype) for i, j in res_work_pairs}
res_dict['success'] = np.zeros(n_elements, dtype=bool)
res_dict['status'] = np.full(n_elements, _EINPROGRESS)
res_dict['nit'] = np.zeros(n_elements, dtype=int)
res_dict['nfev'] = np.zeros(n_elements, dtype=int)
n_elements = math.prod(shape)
active = xp.arange(n_elements) # in-progress element indices
res_dict = {i: xp.zeros(n_elements, dtype=dtype) for i, j in res_work_pairs}
res_dict['success'] = xp.zeros(n_elements, dtype=xp.bool)
res_dict['status'] = xp.full(n_elements, _EINPROGRESS, dtype=xp.int32)
res_dict['nit'] = xp.zeros(n_elements, dtype=xp.int32)
res_dict['nfev'] = xp.zeros(n_elements, dtype=xp.int32)
res = _RichResult(res_dict)
work.args = args

active = _check_termination(work, res, res_work_pairs, active,
check_termination, preserve_shape)
check_termination, preserve_shape, xp)

if callback is not None:
temp = _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape)
customize_result, preserve_shape, xp)
if _call_callback_maybe_halt(callback, temp):
cb_terminate = True

while work.nit < maxiter and active.size and not cb_terminate and n_elements:
while work.nit < maxiter and xp_size(active) and not cb_terminate and n_elements:
x = pre_func_eval(work)

if work.args and work.args[0].ndim != x.ndim:
# `x` always starts as 1D. If the SciPy function that uses
# _loop added dimensions to `x`, we need to
# add them to the elements of `args`.
dims = np.arange(x.ndim, dtype=np.int64)
work.args = [np.expand_dims(arg, tuple(dims[arg.ndim:]))
for arg in work.args]
args = []
for arg in work.args:
n_new_dims = x.ndim - arg.ndim
new_shape = arg.shape + (1,)*n_new_dims
args.append(xp.reshape(arg, new_shape))
work.args = args

x_shape = x.shape
if preserve_shape:
x = x.reshape(shape + (-1,))
x = xp.reshape(x, (shape + (-1,)))
f = func(x, *work.args)
f = np.asarray(f, dtype=dtype)
f = xp.asarray(f, dtype=dtype)
if preserve_shape:
x = x.reshape(x_shape)
f = f.reshape(x_shape)
x = xp.reshape(x, x_shape)
f = xp.reshape(f, x_shape)
work.nfev += 1 if x.ndim == 1 else x.shape[-1]

post_func_eval(x, f, work)

work.nit += 1
active = _check_termination(work, res, res_work_pairs, active,
check_termination, preserve_shape)
check_termination, preserve_shape, xp)

if callback is not None:
temp = _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape)
customize_result, preserve_shape, xp)
if _call_callback_maybe_halt(callback, temp):
cb_terminate = True
break
if active.size == 0:
if xp_size(active) == 0:
break

post_termination_check(work)

work.status[:] = _ECALLBACK if cb_terminate else _ECONVERR
return _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape)
customize_result, preserve_shape, xp)


def _check_termination(work, res, res_work_pairs, active, check_termination,
preserve_shape):
preserve_shape, xp):
# Checks termination conditions, updates elements of `res` with
# corresponding elements of `work`, and compresses `work`.

stop = check_termination(work)

if np.any(stop):
if xp.any(stop):
# update the active elements of the result object with the active
# elements for which a termination condition has been met
_update_active(work, res, res_work_pairs, active, stop, preserve_shape)
_update_active(work, res, res_work_pairs, active, stop, preserve_shape, xp)

if preserve_shape:
stop = stop[active]
Expand All @@ -272,13 +281,20 @@ def _check_termination(work, res, res_work_pairs, active, check_termination,
if not preserve_shape:
# compress the arrays to avoid unnecessary computation
for key, val in work.items():
work[key] = val[proceed] if isinstance(val, np.ndarray) else val
# TODO: replace these try/excepts with a cleaner mechanism; the compressible
# numerical arrays need to be kept separate from non-indexable entries.
if key == 'args':
continue
try:
work[key] = val[proceed]
except (IndexError, TypeError, KeyError): # not a compressible array
work[key] = val
work.args = [arg[proceed] for arg in work.args]

return active


def _update_active(work, res, res_work_pairs, active, mask, preserve_shape):
def _update_active(work, res, res_work_pairs, active, mask, preserve_shape, xp):
# Update `active` indices of the arrays in result object `res` with the
# contents of the scalars and arrays in `update_dict`. When provided,
# `mask` is a boolean array applied both to the arrays in `update_dict`
Expand All @@ -288,34 +304,45 @@ def _update_active(work, res, res_work_pairs, active, mask, preserve_shape):

if mask is not None:
if preserve_shape:
active_mask = np.zeros_like(mask)
active_mask = xp.zeros_like(mask)
active_mask[active] = 1
active_mask = active_mask & mask
for key, val in update_dict.items():
res[key][active_mask] = (val[active_mask] if np.size(val) > 1
else val)
try:
res[key][active_mask] = val[active_mask]
except (IndexError, TypeError, KeyError):
res[key][active_mask] = val
else:
active_mask = active[mask]
for key, val in update_dict.items():
res[key][active_mask] = val[mask] if np.size(val) > 1 else val
try:
res[key][active_mask] = val[mask]
except (IndexError, TypeError, KeyError):
res[key][active_mask] = val
else:
for key, val in update_dict.items():
if preserve_shape and not np.isscalar(val):
val = val[active]
if preserve_shape:
try:
val = val[active]
except (IndexError, TypeError, KeyError):
pass
res[key][active] = val


def _prepare_result(work, res, res_work_pairs, active, shape, customize_result,
preserve_shape):
preserve_shape, xp):
# Prepare the result object `res` by creating a copy, copying the latest
# data from work, running the provided result customization function,
# and reshaping the data to the original shapes.
res = res.copy()
_update_active(work, res, res_work_pairs, active, None, preserve_shape)
_update_active(work, res, res_work_pairs, active, None, preserve_shape, xp)

shape = customize_result(res, shape)

for key, val in res.items():
res[key] = np.reshape(val, shape)[()]
# NOTE(review): `xp.reshape` assumes `val` is an array of namespace `xp`;
# for backends other than NumPy this may fail when `val` is not numeric.
temp = xp.reshape(val, shape)
res[key] = temp[()] if temp.ndim == 0 else temp

res['_order_keys'] = ['success'] + [i for i, j in res_work_pairs]
return _RichResult(**res)
6 changes: 3 additions & 3 deletions scipy/integrate/_tanhsinh.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def _tanhsinh(f, a, b, *, args=(), log=False, maxfun=None, maxlevel=None,
c[inf_a & inf_b] = 0 # takes care of infinite a and b
temp = eim._initialize(f, (c,), args, complex_ok=True,
preserve_shape=preserve_shape)
f, xs, fs, args, shape, dtype = temp
f, xs, fs, args, shape, dtype, xp = temp
a = np.broadcast_to(a, shape).astype(dtype).ravel()
b = np.broadcast_to(b, shape).astype(dtype).ravel()

Expand Down Expand Up @@ -461,7 +461,7 @@ def customize_result(res, shape):
with np.errstate(over='ignore', invalid='ignore', divide='ignore'):
res = eim._loop(work, callback, shape, maxiter, f, args, dtype, pre_func_eval,
post_func_eval, check_termination, post_termination_check,
customize_result, res_work_pairs, preserve_shape)
customize_result, res_work_pairs, xp, preserve_shape)
return res


Expand Down Expand Up @@ -1068,7 +1068,7 @@ def _nsum(f, a, b, step=1, args=(), log=False, maxterms=int(2**20), atol=None,

# Additional elementwise algorithm input validation / standardization
tmp = eim._initialize(f, (a,), args, complex_ok=False)
f, xs, fs, args, shape, dtype = tmp
f, xs, fs, args, shape, dtype, xp = tmp

# Finish preparing `a`, `b`, and `step` arrays
a = xs[0]
Expand Down
10 changes: 6 additions & 4 deletions scipy/optimize/_bracket.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None,

xs = (xl0, xr0)
temp = eim._initialize(func, xs, args)
func, xs, fs, args, shape, dtype = temp # line split for PEP8
func, xs, fs, args, shape, dtype, _ = temp # line split for PEP8
mdhaber marked this conversation as resolved.
Show resolved Hide resolved
xl0, xr0 = xs
xmin = np.broadcast_to(xmin, shape).astype(dtype, copy=False).ravel()
xmax = np.broadcast_to(xmax, shape).astype(dtype, copy=False).ravel()
Expand Down Expand Up @@ -381,7 +381,8 @@ def customize_result(res, shape):

return eim._loop(work, callback, shape, maxiter, func, args, dtype,
pre_func_eval, post_func_eval, check_termination,
post_termination_check, customize_result, res_work_pairs)
post_termination_check, customize_result, res_work_pairs,
xp)


def _bracket_minimum_iv(func, xm0, xl0, xr0, xmin, xmax, factor, args, maxiter):
Expand Down Expand Up @@ -562,7 +563,8 @@ def _bracket_minimum(func, xm0, *, xl0=None, xr0=None, xmin=None, xmax=None,
func, xm0, xl0, xr0, xmin, xmax, factor, args, maxiter = temp

xs = (xl0, xm0, xr0)
func, xs, fs, args, shape, dtype = eim._initialize(func, xs, args)
temp = eim._initialize(func, xs, args)
func, xs, fs, args, shape, dtype, xp = temp

xl0, xm0, xr0 = xs
fl0, fm0, fr0 = fs
Expand Down Expand Up @@ -661,4 +663,4 @@ def customize_result(res, shape):
maxiter, func, args, dtype,
pre_func_eval, post_func_eval,
check_termination, post_termination_check,
customize_result, res_work_pairs)
customize_result, res_work_pairs, xp)