ENH: optimize._chandrupatla: add array API support #20689

Merged
merged 13 commits into from May 13, 2024

Contributor

@mdhaber mdhaber commented May 10, 2024

Reference issue

gh-7242

What does this implement/fix?

This adds array API support to scipy.optimize._chandrupatla and paves the way toward adding array API support to the other elementwise iterative methods (e.g. _differentiate, _tanhsinh, _nsum, _chandrupatla_minimize, and the bracket finders). I'll propose making the rootfinder, minimizer, and bracket finders public shortly.

Additional information

The performance of this function relative to the existing bracketing rootfinders is worst when function evaluations are inexpensive, so I compared the performance of

  • brentq
  • _chandrupatla (NumPy)
  • _chandrupatla (PyTorch, CPU)
  • _chandrupatla (CuPy)

when finding the root of xp.cos(x) - p for many values of p. The bracket is always $[0, \pi]$. brentq's default absolute x-tolerance is 2e-12, so I set _chandrupatla's to 1e-12 to be more than fair, and I confirmed that both solvers meet their respective tolerances. _chandrupatla finds all the roots in a single vectorized call, whereas brentq loops over the values of p with a list comprehension.
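To illustrate what "a single vectorized call" means here, the following is a minimal sketch of an elementwise bracketing solver. It uses plain bisection in NumPy, not Chandrupatla's algorithm, and `vectorized_bisect` is a hypothetical helper written for this illustration; it is not part of SciPy and omits the input validation, result object, and callback support of `_chandrupatla`.

```python
import numpy as np

def vectorized_bisect(f, a, b, args=(), xatol=1e-12, maxiter=200):
    """Elementwise bisection: solve f(x, *args) = 0 for every element at once.

    Hypothetical helper for illustration only (plain bisection, not
    Chandrupatla's method).
    """
    a, b, *args = np.broadcast_arrays(a, b, *args)
    a = np.array(a, dtype=np.float64)  # copies, so the caller's arrays are untouched
    b = np.array(b, dtype=np.float64)
    fa = f(a, *args)
    for _ in range(maxiter):
        if np.all(np.abs(b - a) <= 2 * xatol):
            break
        c = 0.5 * (a + b)
        fc = f(c, *args)  # one function call evaluates every problem
        # Keep the half-interval whose endpoints still have opposite signs.
        left = np.sign(fc) != np.sign(fa)  # True where the root lies in [a, c]
        b = np.where(left, c, b)
        a = np.where(left, a, c)
        fa = np.where(left, fa, fc)
    return 0.5 * (a + b)

# Same test problem as in the benchmark: root of cos(x) - p on [0, pi].
ps = np.linspace(-0.99, 0.99, 1000)
roots = vectorized_bisect(lambda x, p: np.cos(x) - p, 0.0, np.pi, args=(ps,))
```

The Python-level loop runs once per iteration rather than once per root, which is why the per-call overhead is amortized as the number of roots grows.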

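[Figure: log-log plot of execution time (s) versus number of roots for `xp.cos(x) - p`, comparing np, cp, torch, and brentq]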

# import time
# import numpy as np
# import cupy as cp
# import torch
# import matplotlib.pyplot as plt
#
# rng = np.random.default_rng(1638083107694713882823079058616272161)
# from scipy import optimize
# from scipy.optimize._chandrupatla import _chandrupatla
#
# xp = torch
# n = 10
# n = int(n)
# a = xp.asarray(0)
# b = xp.asarray(np.pi)
# ps = xp.asarray(rng.random(size=n) * 2 - 1, dtype=xp.float64)
#
# def f(x, p):
#     return xp.cos(x) - p
#
# tic = time.perf_counter()
# res = _chandrupatla(f, a, b, args=(ps,), xatol=1e-12)
# toc = time.perf_counter()

import time
import numpy as np
import cupy as cp
import torch
import matplotlib.pyplot as plt

rng = np.random.default_rng(1638083107694713882823079058616272161)
from scipy import optimize
from scipy.optimize._chandrupatla import _chandrupatla

ns = np.logspace(0, 6, 30)

times = {np: [], cp: [], torch: [], 'brentq': []}

for xp in [np, cp, torch, 'brentq']:
    brentq = False
    if xp == 'brentq':
        xp = np
        brentq = True

    for n in ns:
        n = int(n)
        a = xp.asarray(0)
        b = xp.asarray(np.pi)
        ps = xp.asarray(rng.random(size=n) * 2 - 1, dtype=xp.float64)

        def f(x, p):
            return xp.cos(x) - p

        if brentq:
            tic = time.perf_counter()
            ref = [optimize.brentq(f, a, b, args=(p,)) for p in ps]
            toc = time.perf_counter()
            times['brentq'].append(toc - tic)
            np.testing.assert_allclose(ref, np.arccos(ps), atol=2e-12)
            continue

        tic = time.perf_counter()
        res = _chandrupatla(f, a, b, args=(ps,), xatol=1e-12)
        toc = time.perf_counter()
        times[xp].append(toc - tic)
        np.testing.assert_allclose(cp.asnumpy(res.x), np.arccos(cp.asnumpy(ps)), atol=1e-12)

plt.loglog(ns, times[np], label='np')
plt.loglog(ns, times[cp], label='cp')
plt.loglog(ns, times[torch], label='torch')
plt.loglog(ns, times['brentq'], label='brentq')
plt.xlabel('number of roots')
plt.ylabel('execution time (s)')
plt.title('Root of `xp.cos(x) - p`')
plt.legend()
plt.show()

The function has a lot of overhead due to bells and whistles (e.g. input validation with nice error messages, rich result object, callback function support, etc.). But for solving a lot of equations, the overhead of function calls eventually becomes problematic for brentq.

In this case, you could probably get better performance with cython_optimize, but for more expensive functions with overhead, like finding the argus(1) distribution ppf (#17719 (comment)), the advantage is much more pronounced.
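For readers unfamiliar with computing a ppf by rootfinding, here is a minimal sketch of the idea: the quantile for probability p is the root of cdf(x) - p. It assumes a monotonically increasing CDF and a valid bracket, uses the standard library's `statistics.NormalDist` as a stand-in for `argus(1)`, and uses scalar bisection rather than brentq or `_chandrupatla`. `ppf_by_bisection` is a hypothetical name.

```python
from statistics import NormalDist

def ppf_by_bisection(cdf, p, a, b, xatol=1e-12):
    """Invert an increasing CDF by bisecting cdf(x) - p on [a, b].

    Hypothetical helper for illustration only.
    """
    fa = cdf(a) - p
    fb = cdf(b) - p
    if fa * fb > 0:
        raise ValueError("cdf(x) - p must change sign over the bracket [a, b]")
    while b - a > 2 * xatol:
        c = 0.5 * (a + b)
        if cdf(c) - p < 0:  # still below the target probability: raise the left end
            a = c
        else:
            b = c
    return 0.5 * (a + b)

dist = NormalDist()  # stand-in distribution (assumption), not argus(1)
ps = [0.005 + 0.99 * i / 98 for i in range(99)]
quantiles = [ppf_by_bisection(dist.cdf, p, -10.0, 10.0) for p in ps]
```

Each quantile costs dozens of scalar `cdf` calls here, so any fixed per-call overhead in the CDF is paid many times over; that is the cost a vectorized solver avoids.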

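[Figure: log-log plot of execution time (s) versus number of probabilities for `argus(1).cdf(x) - p`, comparing np and brentq]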

import time
import numpy as np
import cupy as cp
import torch
import matplotlib.pyplot as plt

rng = np.random.default_rng(1638083107694713882823079058616272161)
from scipy import optimize, stats
from scipy.optimize._chandrupatla import _chandrupatla

ns = np.logspace(0, 4, 30)

times = {np: [], cp: [], torch: [], 'brentq': []}

for xp in [np, 'brentq']:
    brentq = False
    if xp == 'brentq':
        xp = np
        brentq = True

    for n in ns:
        n = int(n)
        a = 0.001
        b = 0.999
        ps = np.linspace(0.005, 0.995, n)

        dist = stats.argus(1)
        def f(x, p):
            return dist.cdf(x) - p

        if brentq:
            tic = time.perf_counter()
            ref = [optimize.brentq(f, a, b, args=(p,)) for p in ps]
            toc = time.perf_counter()
            times['brentq'].append(toc - tic)
            continue

        tic = time.perf_counter()
        res = _chandrupatla(f, a, b, args=(ps,), xatol=1e-12)
        toc = time.perf_counter()
        times[xp].append(toc - tic)

plt.loglog(ns, times[np], label='np')
# plt.loglog(ns, times[cp], label='cp')
plt.loglog(ns, times['brentq'], label='brentq')
plt.xlabel('number of probabilities')
plt.ylabel('execution time (s)')
plt.title('Root of `argus(1).cdf(x) - p`')
plt.legend()
plt.show()

newton is definitely faster than this function, but it should be - the algorithm converges faster. (The advantage of bracketing methods is that convergence is guaranteed if the bracket is valid.) We can add that as another method with the same framework.
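As a rough illustration of the convergence-rate difference, the sketch below counts iterations for Newton's method and for plain bisection on cos(x) - p. Both functions are hypothetical scalar helpers written for this comparison (bisection stands in for a generic bracketing method; it is not Chandrupatla's algorithm), and the starting guess `x0=1.0` is an arbitrary choice.

```python
import math

def newton_iterations(p, x0, xatol=1e-12, maxiter=50):
    """Newton's method on f(x) = cos(x) - p; returns (root, iterations)."""
    x = x0
    for i in range(1, maxiter + 1):
        step = (math.cos(x) - p) / (-math.sin(x))  # f(x) / f'(x)
        x -= step
        if abs(step) <= xatol:
            return x, i
    raise RuntimeError("Newton's method did not converge")

def bisection_iterations(p, a, b, xatol=1e-12):
    """Bisection on f(x) = cos(x) - p over [a, b]; returns (root, iterations)."""
    n = 0
    while b - a > 2 * xatol:
        c = 0.5 * (a + b)
        if math.cos(c) - p > 0:  # cos decreases on [0, pi]: root is to the right
            a = c
        else:
            b = c
        n += 1
    return 0.5 * (a + b), n

p = 0.3
x_newton, n_newton = newton_iterations(p, x0=1.0)
x_bisect, n_bisect = bisection_iterations(p, 0.0, math.pi)
print(n_newton, n_bisect)
```

Newton needs far fewer iterations from a good starting guess, but it can diverge from a poor one (for example, near x = 0 where sin(x) vanishes), whereas bisection only requires a sign change over the bracket.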

The tests are quite strict; I'm ironing out a few failures with alternative backends. Done if CI looks good.

The function currently uses fancy indexing assignment. I can investigate working around that for strict array API support later.
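One possible direction for such a workaround, sketched here under assumptions, is to express the update with `where` instead of in-place indexed assignment. This is not the approach taken in the PR: the sketch uses a boolean mask rather than integer-array indices, assumes the new values are available at full shape, and `update_active` is a hypothetical name.

```python
import numpy as np

def update_active(x, active, x_new, xp=np):
    """Return x with entries where `active` is True replaced by x_new.

    Gives the same result as `x[active] = x_new[active]`, but without
    in-place indexed assignment. Hypothetical helper for illustration.
    """
    return xp.where(active, x_new, x)

x = np.asarray([1.0, 2.0, 3.0, 4.0])
active = np.asarray([True, False, True, False])  # elements still iterating
x_new = x * 10
updated = update_active(x, active, x_new)
```

The trade-off is that `x_new` must be computed for every element, so this form gives up the savings of evaluating only the elements that are still active.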

@mdhaber mdhaber added the enhancement, scipy.optimize, and array types labels May 10, 2024
@mdhaber mdhaber requested a review from tupui May 10, 2024 09:01
@tupui tupui self-assigned this May 10, 2024
Member

@tupui tupui left a comment

Nice, I just have a few questions, the rest is pretty straightforward and LGTM 👍

Review threads (all resolved):
scipy/_lib/_array_api.py
scipy/_lib/_elementwise_iterative_method.py
scipy/optimize/_chandrupatla.py
scipy/optimize/_chandrupatla.py
scipy/optimize/_chandrupatla.py
@mdhaber mdhaber marked this pull request as ready for review May 12, 2024 20:28
@mdhaber mdhaber requested a review from tupui May 12, 2024 20:28
Member

@tupui tupui left a comment

Last changes LGTM, letting the CI run and then I think it's good to merge.

@mdhaber
Contributor Author

mdhaber commented May 13, 2024

Ok! Remaining failures seem unrelated - a package install conflict and slow tests, hopefully one of which is a temporary glitch.

@tupui
Member

tupui commented May 13, 2024

Argh, you have conflicts now

Review thread (outdated, resolved): scipy/optimize/_bracket.py
@tupui tupui merged commit be0d426 into scipy:main May 13, 2024
28 of 31 checks passed
@tupui tupui added this to the 1.14.0 milestone May 13, 2024
Labels: array types, enhancement, scipy._lib, scipy.optimize

3 participants