ENH: array types: add JAX support #20085
Conversation
(The reason I decided to comment over on the DLPack issue is that I recall a conversation about how portability could be increased if we replace occurrences of |
Thanks for working on this Lucas. JAX support will be very nice. And a third library with CPU support (after NumPy and PyTorch) will also be good for testing how generic our array API standard support actually is. Okay, related to the read-only question, it looks like this is the problem you were seeing:
The problem is that Cython doesn't accept read-only arrays when the signature is a regular memoryview. There's a long discussion about this topic in scikit-learn/scikit-learn#10624. Now that we have Cython 3 though, the fix is simple:

```diff
diff --git a/scipy/cluster/_hierarchy.pyx b/scipy/cluster/_hierarchy.pyx
index 814051df2..c59b3de6a 100644
--- a/scipy/cluster/_hierarchy.pyx
+++ b/scipy/cluster/_hierarchy.pyx
@@ -1012,7 +1012,7 @@ def nn_chain(double[:] dists, int n, int method):
     return Z_arr

-def mst_single_linkage(double[:] dists, int n):
+def mst_single_linkage(const double[:] dists, int n):
     """Perform hierarchy clustering using MST algorithm for single linkage.

     Parameters
```

This makes the tests pass (at least for this issue, I tried with the |
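To see why a plain memoryview signature chokes on read-only input, here is a small NumPy/buffer-protocol sketch (illustrative only, not SciPy code; `const double[:]` in Cython 3 corresponds to requesting a read-only buffer instead of a writable one):

```python
import numpy as np

a = np.arange(3.0)
a.setflags(write=False)  # simulate a read-only array (e.g. one coming from JAX)

mv = memoryview(a)       # what a Cython memoryview negotiates via the buffer protocol
print(mv.readonly)       # True: only a read-only buffer is on offer

try:
    mv[0] = 1.0          # writing, as a plain `double[:]` would allow, is refused
except TypeError as exc:
    print("write refused:", exc)
```

A `const double[:]` signature accepts this buffer happily because it never asks for write access.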
thanks! I've removed the copies and added some |
One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under `jax.jit`, because it requires converting array objects to host-side buffers, and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or `pure_callback`) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?
Do you mean for testing purposes, or for library code? For the latter: we should never do device transfers like GPU->host memory under the hood. The array API standard design was careful to not include that. It wasn't even possible at all until very recently, when a way was added to do it with DLPack (for testing purposes). If you mean "convert to
JIT compilers were explicitly considered, and nothing in the standard should be JIT-unfriendly, except for the few clearly marked as data-dependent output shapes and the few dunder methods that are also problematic for lazy arrays. |
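As an aside, the "data-dependent output shapes" exception mentioned above is easy to illustrate with NumPy's `unique`: a tracing JIT only sees shapes and dtypes, so it cannot know the result size ahead of time.

```python
import numpy as np

# Same input shape (3,), but the output shape depends on the *values*:
a = np.unique(np.array([1, 1, 2]))
b = np.unique(np.array([1, 2, 3]))
print(a.shape, b.shape)  # (2,) (3,)
```

This is why the standard splits `unique` into clearly marked `unique_*` functions that lazy/JIT-based libraries may decline to support.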
If this is what you meant, x-ref the 'Dispatching Mechanism' section of gh-18286 |
I mean for actual user-level code: most of the work here will be more-or-less useless for JAX users because array conversions via dlpack cannot be done under JIT without some sort of callback mechanism. |
Okay, I had a look at https://jax.readthedocs.io/en/latest/tutorials/external-callbacks.html and understand what you mean now. It looks fairly straightforward to support (disclaimer: I haven't tried it yet). It'd be taking this current code pattern:

```python
# inside some Python-level scipy function with array API standard support:
x = np.asarray(x)
result = call_some_compiled_code(x)
result = xp.asarray(result)  # back to original array type
```

and replacing it with something like (untested):

```python
def call_compiled_code_helper(x, xp):  # needs *args, **kwargs too
    if is_jax(x):
        result_shape_dtypes = ...  # TODO: figure out how to construct the needed PyTree here
        result = jax.pure_callback(call_some_compiled_code, result_shape_dtypes, x)
    else:
        x = np.asarray(x)
        result = call_some_compiled_code(x)
        result = xp.asarray(result)
```

Use of a utility function like It's interesting that |
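A runnable sketch of the non-JAX path of such a helper (the `is_jax` check and `call_some_compiled_code` kernel here are hypothetical stand-ins, not SciPy code; the JAX branch would route through `jax.pure_callback` with a `jax.ShapeDtypeStruct` describing the result):

```python
import numpy as np

def is_jax(x):
    # hypothetical stand-in for a real JAX-detection utility
    return type(x).__module__.startswith("jax")

def call_some_compiled_code(x):
    # stand-in for a compiled SciPy kernel that needs a host-side buffer
    return np.cumsum(x)

def call_compiled_code_helper(x, xp):
    if is_jax(x):
        # under jax.jit this branch would wrap the kernel in
        # jax.pure_callback(call_some_compiled_code, result_shape_dtypes, x)
        raise NotImplementedError("JAX path not exercised in this sketch")
    x = np.asarray(x)
    result = call_some_compiled_code(x)
    return xp.asarray(result)

out = call_compiled_code_helper([1.0, 2.0, 3.0], np)
print(out)  # [1. 3. 6.]
```

The point of funneling everything through one helper is that the callback plumbing only has to be written once.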
Yeah, something like that is what I had in mind, though |
It is (depending on your definition of "issue") because there's no magic bullet that will do something like take some native function implemented in C/Fortran/Cython inside SciPy and make it run on GPU. The basic state of things is:
In a generic library like SciPy it's almost impossible to support custom kernels on device. Our choices for arrays that don't live on host memory are:
I gave adding Dask another shot just now, but unfortunately things are missing from |
I'd suggest keeping this PR focused on JAX and getting that merged first. That makes it easier to see (also in the future) what had to be done only for JAX. And if we're going to experiment a bit with |
Addresses some of my points at: #20085 (review) and seems to fix about 55 GPU-based array API test failures.

Co-authored-by: Matt Haberland <[email protected]>
```diff
@@ -207,7 +205,6 @@ def test_mlab_linkage_conversion_empty(self, xp):
         xp_assert_equal(from_mlab_linkage(X), X)
         xp_assert_equal(to_mlab_linkage(X), X)

     @skip_xp_backends(cpu_only=True)
```
`from_mlab_linkage` converts with `np.asarray`, so I'll put these back.
@lucascolley FYI some of the I found that having an environment with both JAX and CuPy and testing with |
Nice. Moving forward I will probably have one env with JAX + CuPy and another with PyTorch + CuPy + array-api-strict, and test with both. Things will be easier once I'm back with a GPU. |
The one Windows failure is unrelated:
Okay, time to give it a go. Thanks a lot @lucascolley and all reviewers!
Follow-up steps include:
- deal with item/slice assignment, reducing/removing skips related to that
- look at a callback mechanism to make `jax.jit` work
- check how things look on TPU (e.g. in a Kaggle notebook, see discussion higher up)
I plan to have a look at 1 and 2 tonight.
Unrelated to JAX follow-ups:
- a few more CuPy test failures to deal with in `main`
- deal with other test failures related to `nan_policy`
Thanks Ralf and all reviewers for all of the help here! I plan to have a look at Dask in a few months' time, but anyone else, feel free to tackle it if you get to it before me. A reminder that gh-19900 looks ready to me and should help eliminate some of the GPU failures Tyler was seeing. But no rush if it looks like more work is needed. FYI @izaid , I'm ~1 week out from finals now, so I'll not be working on any PRs for a while. See you on the other side! |
What are the CuPy and |
Good luck with your finals Lucas!
CuPy failures are taken care of in gh-19900. They were:
This makes things work with JAX, at a slight readability cost. Follow up to scipy#20085.
I worked some more on this, adding Using it, some good news and some less good. I could get
The less good news is that the >>> import jax.numpy as jnp
>>> import jax
>>>
>>> def func(x, idx, value):
... return x.at[idx].set(value)
...
>>> func_jit = jax.jit(func)
>>>
>>> x = jnp.arange(5)
>>> idx = x < 3
>>>
>>> func(x, idx, 99)
Array([99, 99, 99, 3, 4], dtype=int32)
>>> func_jit(x, idx, 99)
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5]) The explanation under Boolean indexing into JAX arrays https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError isn't quite satisfactory. There are no dynamic shapes here, so it could work just fine. If the answer is to always use def at_set(
x : Array,
idx: Array | int | slice,
val: Array | int | float | complex,
*,
xp: ModuleType | None = None,
) -> Array:
"""In-place update. Use only if no views are involved."""
xp = array_namespace(x) if xp is None else xp
if is_jax(xp):
if xp.isdtype(idx.dtype, 'bool'):
x = xp.where(idx, x, val)
else:
x = x.at[idx].set(val)
else:
x[idx] = val
return x Which is slower - if is_jax(xp):
if hasattr(idx, 'dtype') and xp.isdtype(idx.dtype, 'bool'):
x = xp.where(idx, x * val, x)
else:
x = x.at[idx].multiply(val)
else:
x[idx] *= val I'll look at it some more - this may be better suited for data-apis/array-api#609. |
I think it would be worth adding something that works for now, even if it's not great. It would avoid all the test skips and make it more obvious what capabilities we need. Once something better comes along, it will be easy to replace. It's probably better than using |
Yeah maybe - I don't want to go too fast though, and add a bunch of code we may regret. Looks like the new version (I edited my comment and pushed a new commit) works though, and is still very fast with JAX.
Let's make sure not to do things like that. Using |
@rgommers FYI I managed to implement the scalar boolean scatter in JAX, and it will be available in the next release. Turns out we had all the necessary logic there already – I just needed to put it together! google/jax#21305 |
Great! Thanks @jakevdp. Looks like a small patch that I can try out pretty easily on top of JAX 0.4.28 - will give it a go later this week. (note to self, since comments are hard to find in this PR: the relevant comment here is #20085 (comment)) |
Resurrecting the conversation about #20085 (comment) based on gh-20935. @rgommers can we add that function that mutates an array at boolean indices where possible and copies when necessary (e.g. JAX)? If we regret it later, we can just change the definition of the function. The only downside I see would be the overhead of an extra function call for non-JAX arrays. The potential upside is JAX support in many functions. If the experiment fails completely or a new array-API standard functionality is made available, we can revert or change wherever the new function is used. Adding/removing all these skips has some cost, too, and I would prefer to deal with, or at least note, any other JAX incompatibilities while we're still working on converting a function rather than having to come back later. @lucascolley @jakevdp anything to add / change? |
Late here but responding to #20085 (comment)
I suspect what you have in mind is the special case when the size of |
@jakevdp would you be willing to open a PR so we have something concrete to discuss? I'd be happy to contribute to the branch by using the function in stats functions and showing that it allows us to remove test skips without performance costs to non-JAX arrays. |
Thanks for the answer @jakevdp!
No, not really. To me there's a fundamental difference between the semantics of a function, and implementation details of it. For I would not call the requirement that
It'd be moving the "if boolean indices, then use
Note also that the error message for the boolean case is misleading, since it recommends a
It would be useful to have an improved error message here, and mention the boolean indexing case in https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html. I'm actually curious what your preference here is. It seems like the |
Reference issue
Towards gh-18867
What does this implement/fix?
First steps on JAX support. To-do:
Additional information
Can do the same for `dask.array` once the problems are fixed over at data-apis/array-api-compat#89.