BF16 matmul slower than F32 matmul on T4 GPU #21212
Comments
Thanks for the report! Here's a ref to the T4 architecture spec: https://images.nvidia.com/aem-dam/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf
T4 doesn't support bfloat16, but JAX (via the XLA GPU compiler) should be falling back to float32. The fact that the result is appreciably slower than native float32 may indicate a bug in the XLA GPU compiler. I'd suggest reporting at https://github.com/openxla/xla
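One way to check what XLA actually compiles for a bf16 matmul is to dump the post-optimization HLO. This is only a sketch with arbitrary shapes; the exact output depends on the JAX/XLA version and the GPU.

```python
import jax
import jax.numpy as jnp

# Arbitrary shapes, chosen only for illustration.
a = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
b = jnp.ones((1024, 1024), dtype=jnp.bfloat16)

# Lower and compile the matmul, then print the optimized HLO so you can see
# which dtypes and ops XLA ends up using (e.g. whether it converts to f32).
lowered = jax.jit(jnp.matmul).lower(a, b)
print(lowered.compile().as_text())
```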
Replied on the OpenXLA bug.
Description
BF16 matmul appears to be slower than F32 matmul on a T4. In my test, BF16 comes out at roughly half the speed. I believe this is a bug: bf16 should be the same speed as (or possibly faster than) f32.
You can repro in a T4 colab with the following:
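The original Colab snippet isn't preserved in this page; the following is a minimal sketch of a benchmark of this kind, with assumed matrix size, iteration count, and helper names.

```python
import time

import jax
import jax.numpy as jnp

def bench(dtype, n=4096, iters=10):
    # Build two random matrices in the requested dtype.
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (n, n)).astype(dtype)
    b = jax.random.normal(key_b, (n, n)).astype(dtype)

    matmul = jax.jit(jnp.matmul)
    matmul(a, b).block_until_ready()  # compile and warm up before timing

    start = time.perf_counter()
    for _ in range(iters):
        out = matmul(a, b)
    out.block_until_ready()  # wait for the async dispatches to finish
    return (time.perf_counter() - start) / iters

for name, dtype in [("float32", jnp.float32),
                    ("bfloat16", jnp.bfloat16),
                    ("float16", jnp.float16)]:
    print(f"{name}: {bench(dtype) * 1e3:.2f} ms per matmul")
```

The block_until_ready() calls matter because JAX dispatches work asynchronously; without them the loop would measure only dispatch overhead rather than the matmuls themselves.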
This results in the following output:
Note how bf16 is much slower than f32. (Side note: I also see that bf16 is far slower than f16, but my understanding is that this is because the T4 doesn't support bf16, so JAX falls back to f32 for the computation.)
System info (python version, jaxlib version, accelerator, etc.)