Compilation time on GPU is proportional to batch size for grad of vmapped Cholesky solve #21313
Comments
Thanks for the report. I'm not sure what's going on, but it seems others are also noticing this: https://stackoverflow.com/questions/78486071/why-does-jax-compilation-time-grow-with-vmap-batch-size
Thanks @jakevdp; both queries are from me :) In this case, note also that the Cholesky without the grad compiles in constant time, so it must be something about the high-level gradient algorithm for Cholesky.
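For reference, a minimal sketch of the comparison being described, using JAX's ahead-of-time lowering API so that only compilation is timed (`loss`, `As`, and `bs` refer to the example in the description below):

```python
# Compiling the forward solve alone reportedly takes roughly constant
# time in nbatch, while the gradient version does not.
jax.jit(loss).lower(As, bs).compile()            # ~constant in nbatch
jax.jit(jax.grad(loss)).lower(As, bs).compile()  # grows with nbatch
```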
I can repro on a Colab A100; I thought it might somehow have to do with constant folding, but even passing …
Interestingly, using your … Would this mean that the problem is at the LLVM level? (Or another NVIDIA representation?)
Another clue: the linear compilation time also shows up for a function that's already written with the extra batch dimension, instead of being vmapped; see the sketch below.
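A sketch of what such a hand-batched variant might look like, using the natively batched `jnp.linalg.cholesky` and `jax.lax.linalg.triangular_solve` (the names here are illustrative, not from the original report):

```python
import jax
import jax.numpy as jnp

def loss_batched(As, bs):
    # Hand-written batched Cholesky solve: the batch axis is carried
    # explicitly through batched linear-algebra primitives, no vmap.
    Ls = jnp.linalg.cholesky(As)  # (nbatch, n, n)
    # Solve L y = b, then L^T x = y, batched over the leading axis.
    ys = jax.lax.linalg.triangular_solve(
        Ls, bs[..., None], left_side=True, lower=True)
    xs = jax.lax.linalg.triangular_solve(
        Ls, ys, left_side=True, lower=True, transpose_a=True)[..., 0]
    return jnp.mean(xs)
```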
Description
The problem is with the grad of the mean of a vmapped Cholesky solution. If I define
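(a minimal sketch of the kind of function described; the names `solve_one`, `loss`, `As`, and `bs` are illustrative, since the original snippet was not preserved)

```python
import jax
import jax.numpy as jnp

def solve_one(A, b):
    # Cholesky-factorize a single SPD matrix and solve A x = b.
    L = jnp.linalg.cholesky(A)
    return jax.scipy.linalg.cho_solve((L, True), b)

def loss(As, bs):
    # vmap the per-matrix solve over the leading batch axis, then
    # reduce to a scalar so the result can be differentiated.
    xs = jax.vmap(solve_one)(As, bs)
    return jnp.mean(xs)
```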
and then transform/compile
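(again a sketch, since the original snippet was not preserved)

```python
# Differentiate the scalar loss with respect to the batch of matrices
# and JIT-compile the result.
grad_loss = jax.jit(jax.grad(loss))
```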
I find that the compilation time grows with `nbatch`. For instance, for the example matrices listed below:

| `nbatch` | time (s) |
|---------:|---------:|
| 16 | 0.532 |
| 32 | 0.507 |
| 64 | 0.516 |
| 128 | 0.580 |
| 256 | 0.652 |
| 512 | 0.822 |
| 1024 | 1.7 |
| 2048 | 2.75 |

What's happening here?
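One way to reproduce such a measurement, as a sketch (`make_problem` is a hypothetical helper, sketched after the next paragraph; `lower().compile()` compiles without executing, so only compilation time is measured):

```python
import time

for nbatch in [16, 32, 64, 128, 256, 512, 1024, 2048]:
    As, bs = make_problem(nbatch)  # hypothetical helper, see below
    t0 = time.perf_counter()
    jax.jit(jax.grad(loss)).lower(As, bs).compile()
    print(nbatch, time.perf_counter() - t0)
```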
To run this example you need matrices such as
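(a sketch: `make_problem`, the size `n=100`, and the SPD construction are stand-ins for whatever the original report used)

```python
def make_problem(nbatch, n=100, seed=0):
    # Build a batch of random SPD matrices plus right-hand sides.
    ka, kb = jax.random.split(jax.random.PRNGKey(seed))
    M = jax.random.normal(ka, (nbatch, n, n))
    As = M @ jnp.swapaxes(M, -1, -2) + n * jnp.eye(n)  # SPD by construction
    bs = jax.random.normal(kb, (nbatch, n))
    return As, bs
```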
System info (python version, jaxlib version, accelerator, etc.)