[RFC] Support for __array__() #9584

Draft
wants to merge 1 commit into
base: main

@gmarkall gmarkall commented May 23, 2024

Pushing this up early to request comments / feedback. I think this is one way to solve rapidsai/cudf#15694, with broad applicability to other use cases: it should allow Numba-jitted functions to be called on anything that has an __array__() method, much like CUDA kernels can be called on anything that supports the CUDA Array Interface.

The implementation handles typing and unboxing by falling back to __array__() only if treating the object as an ndarray fails. It should therefore not impact the fastest path through dispatch: the added code is only executed in cases where dispatch would otherwise have failed.
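To make the fallback order concrete, here is a minimal sketch in plain Python and NumPy. It is not the Numba implementation: `Wrapper`, `describe_argument` and `describe_with_fallback` are hypothetical names, and `describe_argument` only stands in for typing by returning a (dtype name, ndim) pair.

```python
import numpy as np


class Wrapper:
    """Hypothetical container that exposes its data through __array__()."""

    def __init__(self, values):
        self._values = list(values)

    def __array__(self, dtype=None, copy=None):
        # Return an ndarray representation of the wrapped values.
        return np.asarray(self._values, dtype=dtype)


def describe_argument(val):
    """Stand-in for typing: only ndarrays are understood directly."""
    if isinstance(val, np.ndarray):
        return (val.dtype.name, val.ndim)
    raise ValueError(f"cannot determine type of {type(val).__name__}")


def describe_with_fallback(val):
    """Try the ndarray path first; use __array__() only if that fails."""
    try:
        return describe_argument(val)  # fast path, untouched by the fallback
    except ValueError:
        if hasattr(val, "__array__"):
            return describe_argument(val.__array__())
        raise
```

An ndarray never reaches the `except` branch, which is the property described above: the extra work is only paid by arguments that would otherwise have been rejected.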

One minimal test case is added to demonstrate the idea. Additional tests I think would be needed:

  • Testing with and without NRT enabled, though I'm not sure we still have a way to disable NRT. If we can't disable NRT anymore, then there's some dead code like numba_adapt_ndarray that can be deleted.
  • Testing with a mix of __array__() and ndarray arguments.
  • Testing that this doesn't cause reference-counting errors or memory leaks (not that I expect it to, but having a test to prove it would be nice).
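A rough shape for the mixed-argument and reference-counting tests, sketched with NumPy only. `ArrayLike`, `add` and `refcount_delta` are hypothetical names, and `add` is a plain-Python stand-in for a jitted function, so this shows the structure of the checks rather than exercising Numba itself. It also only observes Python-level reference counts, not memory managed by NRT.

```python
import sys

import numpy as np


class ArrayLike:
    """Hypothetical object that exposes data only through __array__()."""

    def __init__(self, values):
        self._data = np.asarray(values, dtype=np.float64)

    def __array__(self, dtype=None, copy=None):
        return self._data if dtype is None else self._data.astype(dtype)


def add(x, y):
    # Stand-in for an @njit function; with this PR the jitted version
    # would be expected to accept the same mix of arguments.
    return np.asarray(x) + np.asarray(y)


def refcount_delta(func, *args, repeats=100):
    """Change in the refcount of the first argument after repeated calls."""
    before = sys.getrefcount(args[0])
    for _ in range(repeats):
        func(*args)
    return sys.getrefcount(args[0]) - before
```

A real test would replace `add` with the jitted function and assert that the delta is zero for both the __array__() argument and the ndarray argument.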

cc @brandon-b-miller @AjayThorve

Another benefit of this PR - the following:

import pandas as pd
from numba import njit

s1 = pd.Series([1, 2, 3])
s2 = pd.Series([4, 5, 6])

df = pd.DataFrame([s1, s2])

@njit
def f(x):
    return x + 1

print(f(s1))
print(f(df))

presently fails on main with an error ending in:

This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class 'pandas.core.series.Series'>

but with this PR, gives:

$ python repro.py 
[2 3 4]
[[2 3 4]
 [5 6 7]]

In general, it becomes possible to pass pandas objects directly to jitted functions.

@gmarkall
Member Author

Also - I can't find a feature request issue to link this to - I'd be surprised if this hasn't been asked for directly at some point though 🤷

@gmarkall gmarkall added the skip_release_notes Skip towncrier requirement label May 23, 2024
@gmarkall
Member Author

The CI failure is triggered by the latent bug #9585.

@@ -687,7 +687,10 @@ def typeof_pyval(self, val):
        try:
            tp = typeof(val, Purpose.argument)
        except ValueError:
            tp = types.pyobject
            if hasattr(val, '__array__'):
                tp = typeof(val.__array__(), Purpose.argument)
Member Author

I think the CI failure calls attention to the fact that I need an additional try: ... except ValueError: ... around this. That alone wouldn't be sufficient to make this work, though, due to #9585 - the NotImplementedError it produces would not be caught by except ValueError.
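The exception-handling point can be shown without Numba. In the sketch below, `typeof_stub`, `resolve_narrow` and `resolve_broad` are hypothetical names; the stub simply raises the two exception types discussed here. Catching both exceptions is only one possible workaround, not necessarily what this PR should do.

```python
def typeof_stub(val):
    """Hypothetical typing function with two distinct failure modes."""
    if isinstance(val, (int, float)):
        return type(val).__name__
    if isinstance(val, str):
        # Mimics a typing path that raises NotImplementedError.
        raise NotImplementedError("no typing rule for str")
    raise ValueError(f"cannot type {type(val).__name__}")


def resolve_narrow(val, fallback="pyobject"):
    """Catches ValueError only; NotImplementedError propagates."""
    try:
        return typeof_stub(val)
    except ValueError:
        return fallback


def resolve_broad(val, fallback="pyobject"):
    """Catches both failure modes and falls back in either case."""
    try:
        return typeof_stub(val)
    except (ValueError, NotImplementedError):
        return fallback
```

NotImplementedError derives from RuntimeError rather than ValueError, which is why `resolve_narrow("x")` raises instead of returning the fallback.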

@gmarkall
Member Author

Notes from triage meeting (very quickly written, will tidy up later):

  • This should make it possible to pass a pandas Series or Dask array to a jitted function.
  • If there's a performance issue due to the use of __array__() (e.g. a CuPy array being converted to a NumPy array), then the library doing the inefficient thing should be responsible for any warnings.
  • Would we want a warning to be emitted when this is done?
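If a warning on the fallback path were wanted, it could look roughly like the sketch below. This is an illustration only: `ArrayFallbackWarning`, `to_ndarray` and `Boxed` are hypothetical names that do not exist in Numba, and whether a warning should be emitted at all is the open question above.

```python
import warnings

import numpy as np


class ArrayFallbackWarning(UserWarning):
    """Hypothetical warning category for implicit __array__() conversions."""


class Boxed:
    """Hypothetical array-like used to exercise the fallback."""

    def __init__(self, values):
        self._values = list(values)

    def __array__(self, dtype=None, copy=None):
        return np.asarray(self._values, dtype=dtype)


def to_ndarray(val, warn=True):
    """Return val as an ndarray, optionally warning when __array__() is used."""
    if isinstance(val, np.ndarray):
        return val  # fast path: no conversion and no warning
    if hasattr(val, "__array__"):
        if warn:
            warnings.warn(
                f"converting {type(val).__name__} via __array__()",
                ArrayFallbackWarning,
                stacklevel=2,
            )
        return val.__array__()
    raise TypeError(f"cannot convert {type(val).__name__} to an ndarray")
```

A dedicated category lets users silence or escalate the message with the standard `warnings` filters without affecting other warnings.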

Labels
2 - In Progress skip_release_notes Skip towncrier requirement