
ENH: array types: add JAX support #20085

Merged · 73 commits into scipy:main · May 18, 2024

Conversation

@lucascolley (Member) commented Feb 13, 2024

Reference issue

Towards gh-18867

What does this implement/fix?

First steps on JAX support. To-do:

Additional information

Can do the same for dask.array once the problems are fixed over at data-apis/array-api-compat#89.

@github-actions bot added labels on Feb 13, 2024: scipy.cluster, scipy._lib, Meson (items related to the introduction of Meson as the new build system for SciPy), array types (items related to array API support and input array validation, see gh-18286), and enhancement (a new feature or improvement).
@lucascolley removed the Meson label on Feb 13, 2024.
@lucascolley (Member, Author) commented Feb 14, 2024

(The reason I decided to comment over on the DLPack issue: I recall a conversation about how portability could be increased if we replaced occurrences of np.asarray with {array_api_compat.numpy, np>=2.0}.from_dlpack. Clearly, portability beyond libraries that np.asarray can coerce is very low priority at the minute, but it is something to consider long-term. Also, DLPack is the idiomatic way to do library interchange, rather than relying on the array-creation function asarray.)
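
For concreteness, a rough sketch of that interchange route (untested here; assumes NumPy >= 1.23 and a CPU-resident JAX array):

import numpy as np
import jax.numpy as jnp

x_jax = jnp.arange(3.0)       # CPU-resident JAX array
x_np = np.from_dlpack(x_jax)  # idiomatic zero-copy interchange, no np.asarray needed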

@rgommers (Member):

Thanks for working on this, Lucas. JAX support will be very nice. And a third library with CPU support (after NumPy and PyTorch) will also be good for testing how generic our array API standard support actually is.

Okay, related to the read-only question, it looks like this is the problem you were seeing:

scipy/cluster/hierarchy.py:1038: in linkage
    result = _hierarchy.mst_single_linkage(y, n)
        method     = 'single'
        method_code = 0
        metric     = 'euclidean'
        n          = 6
        optimal_ordering = False
        xp         = <module 'jax.experimental.array_api' from '/home/rgommers/mambaforge/envs/scipy-dev-jax/lib/python3.11/site-packages/jax/experimental/array_api/__init__.py'>
        y          = array([1.48660687, 2.23606798, 1.41421356, 1.41421356, 1.41421356,
       2.28254244, 0.1       , 1.48660687, 1.48660687, 2.23606798,
       1.        , 1.        , 1.41421356, 1.41421356, 0.        ])
_hierarchy.pyx:1015: in scipy.cluster._hierarchy.mst_single_linkage
    ???
<stringsource>:663: in View.MemoryView.memoryview_cwrapper
    ???
<stringsource>:353: in View.MemoryView.memoryview.__cinit__
    ???
E   ValueError: buffer source array is read-only

The problem is that Cython doesn't accept read-only arrays when the signature is a regular memoryview. There's a long discussion about this topic in scikit-learn/scikit-learn#10624. Now that we have Cython 3 though, the fix is simple:

diff --git a/scipy/cluster/_hierarchy.pyx b/scipy/cluster/_hierarchy.pyx
index 814051df2..c59b3de6a 100644
--- a/scipy/cluster/_hierarchy.pyx
+++ b/scipy/cluster/_hierarchy.pyx
@@ -1012,7 +1012,7 @@ def nn_chain(double[:] dists, int n, int method):
     return Z_arr
 
 
-def mst_single_linkage(double[:] dists, int n):
+def mst_single_linkage(const double[:] dists, int n):
     """Perform hierarchy clustering using MST algorithm for single linkage.
 
     Parameters

This makes the tests pass (at least for this issue; I tried with the dendrogram tests only). The dists input to mst_single_linkage isn't modified in place, so once we tell Cython that by adding const, things are happy.
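
The read-only buffer is easy to reproduce outside Cython, by the way (a quick check, assuming JAX is installed):

import numpy as np
import jax.numpy as jnp

y = np.asarray(jnp.arange(3.0))  # zero-copy view of the immutable JAX buffer
print(y.flags.writeable)         # False: a plain `double[:]` memoryview rejects it,
                                 # while `const double[:]` accepts it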

@lucascolley added the Cython label (issues with the internal Cython code base) on Feb 14, 2024.
@lucascolley (Member, Author):

Thanks! I've removed the copies and added some consts to the Cython file to get the tests to pass. There are still some failures for in-place assignments with indexing, but we can circle back to those once we integrate with the test-skip infrastructure.

Review comment on scipy/conftest.py (outdated, resolved).
@jakevdp (Member) left a review comment:

One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under jax.jit, because it requires converting array objects to host-side buffers, and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or pure_callback) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?
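
A minimal illustration of the limitation (a sketch; the exact error type may differ across JAX versions):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f(x):
    # during tracing, x is an abstract tracer with no host-side buffer
    return np.asarray(x)

# f(jnp.arange(3.0))  # raises jax.errors.TracerArrayConversionError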

@rgommers (Member):

One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under jax.jit, because it requires converting array objects to host-side buffers,

Do you mean for testing purposes, or for library code? For the latter: we should never do device transfers like GPU->host memory under the hood. The array API standard design was careful to not include that. It wasn't even possible at all until very recently, when a way was added to do it with DLPack (for testing purposes).

If you mean "convert to numpy.ndarray before going into Cython/C/C++/Fortran code inside SciPy", then yes, that is happening. That's kinda not an array API standard issue, because it's leaving Python - and that's a very different problem. To avoid compiled code inside SciPy - which indeed won't work with any JIT compiler unless that JIT is specifically aware of the SciPy functionality being called - it'd be necessary to have either a pure Python path (slow) or a matching API inside JAX that can be called (jax.scipy has some functions that we should be deferring to here).

and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or pure_callback) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?

JIT compilers were explicitly considered, and nothing in the standard should be JIT-unfriendly, except for the few functions clearly marked as having data-dependent output shapes and the few dunder methods that are also problematic for lazy arrays.
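
For example, unique is one of the marked data-dependent cases (an illustrative sketch; the size workaround is JAX-specific, not part of the standard):

import jax
import jax.numpy as jnp

x = jnp.array([1, 1, 2, 3, 3])
jnp.unique(x)                                # eager: fine
# jax.jit(jnp.unique)(x)                     # fails: output shape depends on the data
jax.jit(lambda a: jnp.unique(a, size=5))(x)  # works, given a static size bound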

@lucascolley (Member, Author) commented Feb 27, 2024

Do you mean for testing purposes, or for library code? For the latter

If this is what you meant, x-ref the 'Dispatching Mechanism' section of gh-18286

@jakevdp (Member) commented Feb 27, 2024

I mean for actual user-level code: most of the work here will be more-or-less useless for JAX users because array conversions via dlpack cannot be done under JIT without some sort of callback mechanism.

@rgommers (Member):

Okay, I had a look at https://jax.readthedocs.io/en/latest/tutorials/external-callbacks.html and understand what you mean now. jax.pure_callback looks quite interesting indeed. I wasn't familiar with it, but it looks like that may actually solve an important puzzle in dealing with compiled code. It doesn't support GPU execution or auto-differentiation, but getting jax.jit and jax.vmap to work would be a significant step forward.

It looks fairly straightforward to support (disclaimer: I haven't tried it yet). It'd be taking this current code pattern:

# inside some Python-level scipy function with array API standard support:

x = np.asarray(x)
result = call_some_compiled_code(x)
result = xp.asarray(result)  # back to original array type

and replacing it with something like (untested):

def call_compiled_code_helper(x, xp):  # would need *args/**kwargs too
    if is_jax(x):
        result_shape_dtypes = ... # TODO: figure out how to construct the needed PyTree here
        result = jax.pure_callback(call_some_compiled_code, result_shape_dtypes, x)
    else:
        x = np.asarray(x)
        result = call_some_compiled_code(x)
        result = xp.asarray(result)  # back to original array type
    return result

Use of a utility function like call_compiled_code_helper may even make the code shorter and easier to understand. It seems feasible at first sight.

It's interesting that jax.pure_callback transforms JAX arrays to NumPy arrays under the hood already.
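
For the result_shape_dtypes TODO above, one plausible way to fill it in is jax.ShapeDtypeStruct - an untested sketch, with scipy.special.gammaln standing in for call_some_compiled_code:

import jax
import jax.numpy as jnp
import numpy as np
from scipy import special  # stand-in for "some compiled code"

def host_gammaln(x):
    # executes on the host with a concrete NumPy array
    return special.gammaln(np.asarray(x)).astype(x.dtype)

@jax.jit
def gammaln_via_callback(x):
    # describe the result so tracing can proceed without running the callback
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_gammaln, out_spec, x)

print(gammaln_via_callback(jnp.array([1.0, 2.5, 4.0])))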

@jakevdp (Member) commented Feb 27, 2024

Yeah, something like that is what I had in mind, though pure_callback is probably not the right mechanism. JAX doesn't currently have an easy pure-callback-like mechanism for executing custom kernels on device, without the round-trip to host implied by pure_callback. I wonder if this kind of thing will be an issue for other array API libraries?

@rgommers (Member):

I wonder if this kind of thing will be an issue for other array API libraries?

It is (depending on your definition of "issue"), because there's no magic bullet that will take some native function implemented in C/Fortran/Cython inside SciPy and make it run on GPU.

The basic state of things is:

  • Functions implemented in pure Python are unproblematic; with array API support they get to run on GPU/TPU, gain autograd support, etc.
    • With a few exceptions: functions using unique and other data-dependent shapes, iterative algorithms with a stopping/branching criterion that requires eager evaluation, and functions using in-place operations.
  • As soon as you hit compiled code, things get harder: everything that worked before with NumPy only will still work, but autograd and GPU execution won't.

JAX doesn't currently have an easy pure-callback-like mechanism for executing custom kernels on device, without the round-trip to host implied by pure_callback.

In a generic library like SciPy it's almost impossible to support custom kernels on device. Our choices for arrays that don't live in host memory are:

  • find a matching function in the other library - e.g., we can explicitly defer to everything in jax.scipy, cupyx.scipy and torch.fft/linalg/special (see the sketch below),
  • raise an exception,
  • do an automatic to/from-host round trip (we haven't considered this a good idea before, since data transfers can be very expensive - but apparently that's what pure_callback prefers over raising).
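
To make the first option concrete, a hypothetical dispatch helper (the xp.__name__ checks are a simplification, not SciPy's actual mechanism):

import numpy as np

def gammaln_dispatch(x, xp):
    # hypothetical helper; xp is the array namespace of x
    if xp.__name__.startswith('jax'):
        import jax.scipy.special
        return jax.scipy.special.gammaln(x)    # stays on device, jit-friendly
    if xp.__name__.startswith('cupy'):
        import cupyx.scipy.special
        return cupyx.scipy.special.gammaln(x)  # stays on the GPU
    from scipy import special
    return xp.asarray(special.gammaln(np.asarray(x)))  # host path for NumPy & co.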

@lucascolley (Member, Author) commented Feb 28, 2024

I gave adding Dask another shot just now, but unfortunately things like float64 are missing from dask.array, which makes most of our test code fail. Perhaps we will have to switch to using the wrapped namespaces throughout the tests (this is awkward because we still need to imitate an array from the unwrapped namespace being passed as input).

x-ref dask/dask#10387 (comment)

@rgommers (Member):

I'd suggest keeping this PR focused on JAX and getting that merged first. That makes it easier to see (also in the future) what had to be done just for JAX. And if we're going to experiment a bit with jax.jit, this PR may well grow further.

rgommers pushed a commit that referenced this pull request May 18, 2024
Addresses some of my points at:
#20085 (review)
and seems to fix about 55 GPU-based array API test failures

Co-authored-by: Matt Haberland <[email protected]>
@@ -207,7 +205,6 @@ def test_mlab_linkage_conversion_empty(self, xp):
xp_assert_equal(from_mlab_linkage(X), X)
xp_assert_equal(to_mlab_linkage(X), X)

@skip_xp_backends(cpu_only=True)
Review comment (Member):

from_mlab_linkage converts with np.asarray, so I'll put these back.

@rgommers (Member):

@lucascolley FYI, some of the cpu_only=True tests that pass with JAX on GPU do so only because np.asarray(a_jax_cuda_array) works. However, it is very inefficient, the code is often not correct with JAX anyway (the reverse xp.asarray call doesn't put the data back on the GPU), and it won't work for either CuPy or PyTorch. So I'll undo such test changes.

I found that having an environment with both JAX and CuPy and testing with -b all is a nice way to test, since they have quite different constraints. PyTorch is harder to install in the same env as JAX, but if things work for both JAX and CuPy then they'll most likely work for PyTorch as well.

@lucascolley (Member, Author):

I found that having an environment with both JAX and CuPy and testing with -b all is a nice way to test, since they have quite different constraints. PyTorch is harder to install in the same env as JAX, but if things work for both JAX and CuPy then they'll most likely work for PyTorch as well.

Nice. Moving forward I will probably have one env with JAX + CuPy and another with PyTorch + CuPy + array-api-strict, and test with both. Things will be easier once I'm back with a GPU.

@rgommers (Member):

The one Windows failure is unrelated:

 FAILED scipy\optimize\tests\test_constraint_conversion.py::TestNewToOld::test_individual_constraint_objects

@rgommers (Member) left a review comment:

Okay, time to give it a go. Thanks a lot @lucascolley and all reviewers!

Follow-up steps include:

  1. deal with item/slice assignment, reducing/removing skips related to that
  2. look at a callback mechanism to make jax.jit work
  3. check how things look on TPU (e.g. in a Kaggle notebook, see discussion higher up)

I plan to have a look at 1 and 2 tonight.

Unrelated to JAX follow-ups:

  • a few more CuPy test failures to deal with in main
  • deal with other test failures related to nan_policy

@rgommers merged commit 7192a1c into scipy:main on May 18, 2024
@lucascolley deleted the jax branch on May 18, 2024 at 10:21
@lucascolley (Member, Author):

thanks Ralf and all reviewers for all of the help here! I plan to have a look at Dask in a few months' time, but anyone else, feel free to tackle it if you get to it before me.

A reminder that gh-19900 looks ready to me and should help eliminate some of the GPU failures Tyler was seeing. But no rush if it looks like more work is needed.

FYI @izaid , scipy._lib._array_api.scipy_namespace_for exists now.

I'm ~1 week out from finals now, so I'll not be working on any PRs for a while. See you on the other side!

@mdhaber (Contributor) commented May 18, 2024

What are the CuPy and nan_policy failures? I can probably fix them today.

@rgommers (Member) commented May 18, 2024

Good luck with your finals Lucas!

What are the CuPy and nan_policy failures? I can probably fix them today.

CuPy failures are taken care of in gh-19900. They were:

_______________________________ TestFFTThreadSafe.test_ihfft[cupy] ________________________________
scipy/fft/tests/test_basic.py:460: in test_ihfft
    self._test_mtsame(fft.ihfft, a, xp=xp)
        a          = array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]])
        self       = <scipy.fft.tests.test_basic.TestFFTThreadSafe object at 0x754efabfef60>
        xp         = <module 'cupy' from '/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/__init__.py'>
scipy/fft/tests/test_basic.py:434: in _test_mtsame
    q.get(timeout=5), expected,
        args       = (array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
   ....,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]]),)
        expected   = array([[1.-0.j, 0.-0.j, 0.-0.j, ..., 0.+0.j, 0.+0.j, 0.-0.j],
       [1.-0.j, 0.-0.j, 0.-0.j, ..., 0.+0.j, 0.+0.j, 0.-...  [1.-0.j, 0.-0.j, 0.-0.j, ..., 0.+0.j, 0.+0.j, 0.-0.j],
       [1.-0.j, 0.-0.j, 0.-0.j, ..., 0.+0.j, 0.+0.j, 0.-0.j]])
        func       = <uarray multimethod 'ihfft'>
        i          = 0
        q          = <queue.Queue object at 0x754ec6b60290>
        self       = <scipy.fft.tests.test_basic.TestFFTThreadSafe object at 0x754efabfef60>
        t          = [<Thread(Thread-354 (worker), stopped 128979476416192)>, <Thread(Thread-355 (worker), stopped 128979444958912)>, <Thre...>, <Thread(Thread-358 (worker), stopped 128979677742784)>, <Thread(Thread-359 (worker), stopped 128979667257024)>, ...]
        worker     = <function TestFFTThreadSafe._test_mtsame.<locals>.worker at 0x754ed013be20>
        xp         = <module 'cupy' from '/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/__init__.py'>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/queue.py:179: in get
    raise Empty
E   _queue.Empty
        block      = True
        endtime    = 145.408170417
        remaining  = -0.0001208719999965524
        self       = <queue.Queue object at 0x754ec6b60290>
        timeout    = 5

During handling of the above exception, another exception occurred:
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/_pytest/runner.py:341: in from_call
    result: Optional[TResult] = func()
        cls        = <class '_pytest.runner.CallInfo'>
        duration   = 5.902584181999998
        excinfo    = <ExceptionInfo PytestUnhandledThreadExceptionWarning('Exception in thread Thread-368 (worker)\n\nTraceback (most recen.../cuda/cufft.pyx", line 169, in cupy.cuda.cufft.check_result\ncupy.cuda.cufft.CuFFTError: CUFFT_EXEC_FAILED\n') tblen=9>
        func       = <function call_and_report.<locals>.<lambda> at 0x754f987cff60>
        precise_start = 139.640223733
        precise_stop = 145.542807915
        reraise    = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
        result     = None
        start      = 1716057866.3666894
        stop       = 1716057872.2692754
        when       = 'call'
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/_pytest/runner.py:241: in <lambda>
    lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise
        item       = <Function test_ihfft[cupy]>
        kwds       = {}
        runtest_hook = <HookCaller 'pytest_runtest_call'>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/pluggy/_hooks.py:513: in __call__
    return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
        firstresult = False
        kwargs     = {'item': <Function test_ihfft[cupy]>}
        self       = <HookCaller 'pytest_runtest_call'>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/pluggy/_manager.py:120: in _hookexec
    return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
        firstresult = False
        hook_name  = 'pytest_runtest_call'
        kwargs     = {'item': <Function test_ihfft[cupy]>}
        methods    = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/home/rgommers/mambaforge/envs/scipy-dev-jax-cu...=None>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x754fd2f5fbc0>>, ...]
        self       = <_pytest.config.PytestPluginManager object at 0x754fdc39a270>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/_pytest/threadexception.py:87: in pytest_runtest_call
    yield from thread_exception_runtest_hook()
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/_pytest/threadexception.py:77: in thread_exception_runtest_hook
    warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
E   pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-368 (worker)
E   
E   Traceback (most recent call last):
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
E       self.run()
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/threading.py", line 1010, in run
E       self._target(*self._args, **self._kwargs)
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/fft/tests/test_basic.py", line 419, in worker
E       q.put(func(*args))
E             ^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/fft/_backend.py", line 28, in __ua_function__
E       return fn(*args, **kwargs)
E              ^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/fft/_basic_backend.py", line 90, in ihfft
E       return _execute_1D('ihfft', _pocketfft.ihfft, x, n=n, axis=axis, norm=norm,
E              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/fft/_basic_backend.py", line 34, in _execute_1D
E       return xp_func(x, n=n, axis=axis, norm=norm)
E              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/_lib/array_api_compat/_internal.py", line 28, in wrapped_f
E       return f(*args, xp=xp, **kwargs)
E              ^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/_lib/array_api_compat/common/_fft.py", line 147, in ihfft
E       res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
E             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/fft/_fft.py", line 1050, in ihfft
E       return rfft(a, n, axis, _swap_direction(norm)).conj()
E              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/fft/_fft.py", line 840, in rfft
E       return _fft(a, (n,), (axis,), norm, cufft.CUFFT_FORWARD, 'R2C')
E              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/fft/_fft.py", line 248, in _fft
E       a = _exec_fft(a, direction, value_type, norm, axes[-1], overwrite_x)
E           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/fft/_fft.py", line 191, in _exec_fft
E       plan.fft(a, out, direction)
E     File "cupy/cuda/cufft.pyx", line 500, in cupy.cuda.cufft.Plan1d.fft
E     File "cupy/cuda/cufft.pyx", line 520, in cupy.cuda.cufft.Plan1d._single_gpu_fft
E     File "cupy/cuda/cufft.pyx", line 1145, in cupy.cuda.cufft.execD2Z
E     File "cupy/cuda/cufft.pyx", line 169, in cupy.cuda.cufft.check_result
E   cupy.cuda.cufft.CuFFTError: CUFFT_EXEC_FAILED
        cm         = <_pytest.threadexception.catch_threading_exception object at 0x754efabff440>
        msg        = 'Exception in thread Thread-368 (worker)\n\nTraceback (most recent call last):\n  File "/home/rgommers/mambaforge/envs...File "cupy/cuda/cufft.pyx", line 169, in cupy.cuda.cufft.check_result\ncupy.cuda.cufft.CuFFTError: CUFFT_EXEC_FAILED\n'
        thread_name = 'Thread-368 (worker)'
===================================== short test summary info =====================================
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_ifft[cupy] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-112 (worker)
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_rfft[cupy] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-176 (worker)
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_ihfft[cupy] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-368 (worker)

The nan_policy failures I gave higher up; I'll circle back to those after reviewing gh-19900. EDIT: see gh-20748.

rgommers added a commit to rgommers/scipy that referenced this pull request May 19, 2024
This makes things work with JAX, at a slight readability cost.
Follow up to scipy#20085.
@rgommers (Member) commented May 19, 2024

deal with item/slice assignment, reducing/removing skips related to that

I worked some more on this, adding at_set/at_add etc. as in-place equivalents: https://github.com/scipy/scipy/compare/main...rgommers:scipy:array-types-inplace-ops?expand=1. It's a slight readability hit to replace Z[i//2, 1] = -2 with Z = at_set(Z, (i//2, 1), -2), but acceptable in many places (an opt-in mode for JAX to recognize regular in-place syntax would be way better, though).

Using it, some good news and some less good. I could get cluster.whiten to work with jax.jit on CPU and GPU with some minor tweaks. And it helps performance:

    whiten_jit(face).block_until_ready()  # do the JIT compilation
    face_gpu = face  # JAX defaults to GPU if that is available
    face_cpu = jax.device_put(face, jax.devices('cpu')[0])
    face_np = np.asarray(face)

    %timeit cluster.vq.whiten(face_np)   #  22 ms
    %timeit cluster.vq.whiten(face_cpu)  #   6 ms
    %timeit whiten_jit(face_cpu)         #   3 ms
    %timeit cluster.vq.whiten(face_gpu)  # 700 us
    %timeit whiten_jit(face_gpu)         # 275 us
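
(For context, whiten_jit above was presumably constructed roughly like this - an assumption, since the branch's exact setup isn't shown:)

import jax
from scipy import cluster

whiten_jit = jax.jit(cluster.vq.whiten)  # assumed construction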

The less good news is that the at[idx].set() syntax still doesn't work for boolean indexing:

>>> import jax.numpy as jnp
>>> import jax
>>>
>>> def func(x, idx, value):
...     return x.at[idx].set(value)
...
>>> func_jit = jax.jit(func)
>>>
>>> x = jnp.arange(5)
>>> idx = x < 3
>>>
>>> func(x, idx, 99)
Array([99, 99, 99,  3,  4], dtype=int32)
>>> func_jit(x, idx, 99)
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

The explanation under "Boolean indexing into JAX arrays" at https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError isn't quite satisfactory. There are no dynamic shapes here, so it could work just fine. If the answer is to always use where, then it's just missing support for syntax that could be translated to jnp.where internally. So it seems like we need further branching under is_jax, like this:

def at_set(
        x : Array,
        idx: Array | int | slice,
        val: Array | int | float | complex,
        *,
        xp: ModuleType | None = None,
    ) -> Array:
    """In-place update. Use only if no views are involved."""
    xp = array_namespace(x) if xp is None else xp
    if is_jax(xp):
        if xp.isdtype(idx.dtype, 'bool'):
            x = xp.where(idx, x, val)
        else:
            x = x.at[idx].set(val)
    else:
        x[idx] = val
    return x

This is slower - and I'm also not sure whether jax.jit will complain about the if xp.isdtype line, because of Python control flow involving an array. EDIT: it works, with a tweak (even slower):

    if is_jax(xp):
        if hasattr(idx, 'dtype') and xp.isdtype(idx.dtype, 'bool'):
            x = xp.where(idx, x * val, x)
        else:
            x = x.at[idx].multiply(val)
    else:
        x[idx] *= val

I'll look at it some more - this may be better suited for data-apis/array-api#609.

@mdhaber (Contributor) commented May 19, 2024

I think it would be worth adding something that works for now, even if it's not great. It would avoid all the test skips and make it more obvious what capabilities we need. Once something better comes along, it will be easy to replace. It's probably better than using where, which we're tempted to use otherwise.

@rgommers (Member):

I think it would be worth adding something that works for now, even if it's not great.

Yeah, maybe - I don't want to go too fast though and add a bunch of code we may regret. Looks like the new version works, though (I edited my comment and pushed a new commit), and it is still very fast with JAX.

It's probably better than using where, which we're tempted to use otherwise.

Let's make sure not to do things like that. Using where could potentially be bad for performance with numpy, which would not be helpful. Skips are better for now.

@jakevdp (Member) commented May 20, 2024

@rgommers FYI I managed to implement the scalar boolean scatter in JAX, and it will be available in the next release. Turns out we had all the necessary logic there already – I just needed to put it together! google/jax#21305

@rgommers (Member):

Great! Thanks @jakevdp. Looks like a small patch that I can try out pretty easily on top of JAX 0.4.28 - will give it a go later this week.

(note to self, since comments are hard to find in this PR: the relevant comment here is #20085 (comment))

@mdhaber (Contributor) commented Jun 11, 2024

Resurrecting the conversation about #20085 (comment) based on gh-20935.

@rgommers can we add that function that mutates an array at boolean indices where possible and copies when necessary (e.g. JAX)? If we regret it later, we can just change the definition of the function. The only downside I see would be the overhead of an extra function call for non-JAX arrays. The potential upside is JAX support in many functions. If the experiment fails completely or new array API standard functionality becomes available, we can revert or change wherever the new function is used. Adding/removing all these skips has some cost too, and I would prefer to fix, or at least note, any other JAX incompatibilities while we're still working on converting a function rather than having to come back later.

@lucascolley @jakevdp anything to add / change?

@jakevdp (Member) commented Jun 11, 2024

Late here but responding to #20085 (comment)

There are no dynamic shapes here, so it could work just fine.

a.at[mask].set(arr) in general does require dynamic shapes: it is only valid if arr is broadcast-compatible with the number of True values in mask, and the number of True values is dynamic.

I suspect what you have in mind is the special case when the size of arr is 1, and so we know a priori that it broadcasts with an array of any size. Still, the semantics of lax.scatter require actually instantiating that array. I've explored the idea of overloading JAX's arr.at[].set() to lower to lax.select rather than lax.scatter in this particular special case, but overall it seems like it adds undue complexity to the implementation and to the user's mental model of what this function does.

@mdhaber (Contributor) commented Jun 11, 2024

@jakevdp would you be willing to open a PR so we have something concrete to discuss? I'd be happy to contribute to the branch by using the function in stats functions and showing that it allows us to remove test skips without performance costs to non-JAX arrays.

@rgommers (Member):

Thanks for the answer @jakevdp!

I suspect what you have in mind is the special case when the size of arr is 1

No, not really. To me there's a fundamental difference between the semantics of a function and its implementation details. For b = a.at[mask].set(vals), it doesn't matter whether vals is a scalar or an array with a compatible shape, or how many True values there are in mask. The shape of b is the same as that of a in all cases, so the semantics of the function do not involve any dynamic shapes. This is very different from y = x[mask], where y.shape depends on the values in mask.

I would not call the requirement that mask and vals are broadcast-compatible a dynamic shape. It's more like input validation that may raise an error (or propagate nan's if you can't raise), just as there are functions that don't deal with inf/nan, negative values, singular matrices, or any other case of input values not meeting a function's requirements. The output shape itself cannot change.
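
The contrast in one eager-mode snippet:

import jax.numpy as jnp

a = jnp.arange(5)
mask = a < 3
b = a.at[mask].set(0)  # b.shape == a.shape whatever mask contains: static
y = a[mask]            # y.shape depends on mask's True count: dynamic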

I've explored the idea of overloading JAX's arr.at[].set() to lower to lax.select rather than lax.scatter in this particular special case, but overall it seems like it adds undue complexity to the implementation and to the user's mental model of what this function does.

It'd be moving the "if boolean indices, then use where or select under the hood" if-else logic from user-land to a single place inside JAX, I think. So my expectation is that it would reduce the complexity of the mental model, at least when coming from NumPy or PyTorch. Because right now it seems like we have to do the following translation for in-place ops (where val is either a Python scalar or an array; sketched after the list):

  • x += val: works
  • x[int_indices] += val: TypeError -> use x.at[int_indices].add(val)
  • x[bool_mask] += val: TypeError -> use jnp.where(bool_mask, x + val, x) if val is a scalar, and ?? if it's an array
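
As a runnable sketch of the three cases (scalar val):

import jax.numpy as jnp

x = jnp.arange(5.0)
val = 1.0
idx = jnp.array([0, 2])
mask = x < 3

x = x + val                      # plain arithmetic: works
x = x.at[idx].add(val)           # integer indices: .at[].add
x = jnp.where(mask, x + val, x)  # boolean mask: route through where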

Note also that the error message for the boolean case is misleading, since it recommends a .at method but we need where instead:

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

It would be useful to have an improved error message here, and mention the boolean indexing case in https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.

I'm actually curious what your preference here is. It seems like the at_set/at_add etc. implementations I posted above are still non-JAX-like. But replacing a line like x[mask] += 2 with xp.where(mask, x + 2, x) is both harder to read and a lot less efficient with NumPy & co. And for x[mask] += y it gets worse (I don't even know how to do that properly in JAX, and it'd lose jit-ability since you're then forced to break up the op, which does use dynamic shapes). So I don't really know what the best way forward is here.

Labels: array types · Cython · enhancement · scipy.cluster · scipy.fft · scipy._lib · scipy.special · scipy.stats