Compilation time on GPU is proportional to batch size for grad of vmapped Cholesky solve #21313

Open
vallis opened this issue May 20, 2024 · 5 comments
Labels: bug

Comments

vallis commented May 20, 2024

Description

The problem is with the grad of the mean of a vmapped Cholesky solve. If I define

```python
def func(pars):
    # fmat, one, ones are defined in the setup below
    ftf = fmat @ jax.numpy.diag(pars**2) @ fmat.T + one
    cf = jax.scipy.linalg.cho_factor(ftf)
    b = jax.scipy.linalg.cho_solve(cf, ones)
    return b.mean()
```

and then transform/compile

```python
jvg = jax.value_and_grad(lambda pars: jax.vmap(func)(pars).mean())
pars = jax.random.normal(jax.random.PRNGKey(0), (nbatch, 2 * ngp))
jjvg = jax.jit(jvg).lower(pars).compile()
```

I find that the compilation time grows with `nbatch`. For the example matrices listed below:

| nbatch | compile time (s) |
| ------:| ----------------:|
| 16     | 0.532 |
| 32     | 0.507 |
| 64     | 0.516 |
| 128    | 0.580 |
| 256    | 0.652 |
| 512    | 0.822 |
| 1024   | 1.7   |
| 2048   | 2.75  |
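For reference, the timings above were collected with a loop roughly like this (a sketch assuming `time.perf_counter`; it reuses `jvg` from above and `ngp` from the setup below):

```python
import time

for nbatch in [16, 32, 64, 128, 256, 512, 1024, 2048]:
    pars = jax.random.normal(jax.random.PRNGKey(0), (nbatch, 2 * ngp))
    lowered = jax.jit(jvg).lower(pars)   # tracing/lowering, not timed here
    t0 = time.perf_counter()
    lowered.compile()                    # XLA compilation only
    print(nbatch, time.perf_counter() - t0)
```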

What's happening here?

To run this example you need matrices such as

```python
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)  # required for float64 in JAX

nobs, ngp = 256, 64
t = np.linspace(0, 1, nobs)
f = np.arange(1, ngp + 1, dtype=np.float64)

fmat = np.zeros((nobs, 2 * ngp), dtype=np.float64)
fmat[:, ::2] = np.sin(2.0 * np.pi * f * t[:, np.newaxis])
fmat[:, 1::2] = np.cos(2.0 * np.pi * f * t[:, np.newaxis])

one = jax.numpy.identity(nobs, dtype=np.float64)
ones = jax.numpy.ones(nobs, dtype=np.float64)
```

System info (python version, jaxlib version, accelerator, etc.)

JAX 0.4.26, CUDA 12.2 with driver 535.104.05, NVIDIA V100; Python 3.10.12 on Linux (Colab).
vallis added the bug label on May 20, 2024
jakevdp (Collaborator) commented May 20, 2024

Thanks for the report. I'm not sure what's going on, but it seems others are also noticing this: https://stackoverflow.com/questions/78486071/why-does-jax-compilation-time-grow-with-vmap-batch-size

vallis (Author) commented May 20, 2024

Thanks @jakevdp; both queries are from me :) In this case, note also that the Cholesky solve without the grad compiles in constant time, so it must be something about the gradient rule for Cholesky.
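For concreteness, the no-grad baseline is the same pipeline with the `value_and_grad` wrapper dropped; roughly:

```python
# Same vmapped Cholesky solve, but no gradient: compile time stays
# roughly constant as nbatch grows.
jv = jax.jit(lambda pars: jax.vmap(func)(pars).mean())
jjv = jv.lower(pars).compile()
```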

jakevdp (Collaborator) commented May 20, 2024

I can repro on a Colab A100. I thought it might somehow have to do with constant folding, but even when passing fmat as an argument and defining one and ones inside the function, I still see the batch-dependent compile time.
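Roughly the shape of what I tried (a sketch, not my exact code):

```python
def func2(pars, fmat):
    # one and ones built inside the function, so nothing is closed over as a constant
    one = jax.numpy.identity(fmat.shape[0], dtype=fmat.dtype)
    ones = jax.numpy.ones(fmat.shape[0], dtype=fmat.dtype)
    ftf = fmat @ jax.numpy.diag(pars**2) @ fmat.T + one
    cf = jax.scipy.linalg.cho_factor(ftf)
    return jax.scipy.linalg.cho_solve(cf, ones).mean()

# differentiate w.r.t. pars only; fmat is now a traced argument, not a baked-in constant
jvg2 = jax.value_and_grad(
    lambda pars, fmat: jax.vmap(func2, in_axes=(0, None))(pars, fmat).mean(),
    argnums=0)
jjvg2 = jax.jit(jvg2).lower(pars, fmat).compile()
```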

vallis (Author) commented May 20, 2024

Interestingly, using your make_hlo (which I just found in #7949) shows that the XLA code for, say, nbatch = 64 and nbatch = 512 is essentially the same, except for the shape change 64 -> 512.

Would this mean that the problem is at the LLVM level? (Or in another NVIDIA-specific representation?)
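For anyone repeating the comparison without make_hlo, the public lowering API gives the same text (a sketch; `pars_64` and `pars_512` stand for inputs with the two batch sizes):

```python
# Pre-optimization IR emitted by JAX: diff these to confirm only shapes change.
text_64 = jax.jit(jvg).lower(pars_64).as_text()
text_512 = jax.jit(jvg).lower(pars_512).as_text()

# Post-optimization HLO after XLA's compilation passes, where any
# batch-size-dependent blowup would have to show up.
opt_64 = jax.jit(jvg).lower(pars_64).compile().as_text()
```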

vallis (Author) commented May 21, 2024

Another clue: the linear compilation time also occurs for a function that is written with the extra batch dimension from the start, instead of being vmapped. So the problem must be the batched grad of the Cholesky solve itself.
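A sketch of what I mean by written with the extra batch dimension (my reconstruction, assuming the batched products are expressed with einsum and that cho_factor/cho_solve broadcast over the leading axis):

```python
def func_batched(pars):
    # pars: (nbatch, 2*ngp); build the (nbatch, nobs, nobs) stack directly
    ftf = jax.numpy.einsum('ik,bk,jk->bij', fmat, pars**2, fmat) + one
    cf = jax.scipy.linalg.cho_factor(ftf)
    rhs = jax.numpy.broadcast_to(ones, ftf.shape[:-1])  # (nbatch, nobs)
    return jax.scipy.linalg.cho_solve(cf, rhs).mean()

jvg_b = jax.value_and_grad(func_batched)
jjvg_b = jax.jit(jvg_b).lower(pars).compile()
```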
