ggml-cuda.so is 90mb with -arch=all #7156

Open
jart opened this issue May 9, 2024 · 1 comment

jart (Contributor) commented May 9, 2024

The CUDA implementation for GGML_OP_FLASH_ATTN_EXT is as large as the rest of ggml-cuda combined.

master jart@luna:~/llama.cpp$ ls -Shal ggml-cuda/*.o
-rw-rw-r-- 1 jart jart 3.9M May  8 19:37 ggml-cuda/fattn.o
-rw-rw-r-- 1 jart jart 2.4M May  8 19:37 ggml-cuda/mmvq.o
-rw-rw-r-- 1 jart jart 335K May  8 19:37 ggml-cuda/mmq.o
-rw-rw-r-- 1 jart jart 316K May  8 19:37 ggml-cuda/binbcast.o
-rw-rw-r-- 1 jart jart 265K May  8 19:37 ggml-cuda/convert.o
-rw-rw-r-- 1 jart jart 197K May  8 19:37 ggml-cuda/softmax.o
-rw-rw-r-- 1 jart jart 193K May  8 19:37 ggml-cuda/cpy.o
-rw-rw-r-- 1 jart jart 143K May  8 19:37 ggml-cuda/dmmv.o
-rw-rw-r-- 1 jart jart 121K May  8 19:37 ggml-cuda/getrows.o
-rw-rw-r-- 1 jart jart 113K May  8 19:37 ggml-cuda/norm.o
-rw-rw-r-- 1 jart jart 109K May  8 19:37 ggml-cuda/rope.o
-rw-rw-r-- 1 jart jart  96K May  8 19:37 ggml-cuda/unary.o
-rw-rw-r-- 1 jart jart  85K May  8 19:37 ggml-cuda/im2col.o
-rw-rw-r-- 1 jart jart  72K May  8 19:37 ggml-cuda/argsort.o
-rw-rw-r-- 1 jart jart  71K May  8 19:37 ggml-cuda/pool2d.o
-rw-rw-r-- 1 jart jart  67K May  8 19:37 ggml-cuda/acc.o
-rw-rw-r-- 1 jart jart  67K May  8 19:37 ggml-cuda/alibi.o
-rw-rw-r-- 1 jart jart  66K May  8 19:37 ggml-cuda/upscale.o
-rw-rw-r-- 1 jart jart  66K May  8 19:37 ggml-cuda/concat.o
-rw-rw-r-- 1 jart jart  66K May  8 19:37 ggml-cuda/tsembd.o
-rw-rw-r-- 1 jart jart  66K May  8 19:37 ggml-cuda/diagmask.o
-rw-rw-r-- 1 jart jart  66K May  8 19:37 ggml-cuda/sumrows.o
-rw-rw-r-- 1 jart jart  66K May  8 19:37 ggml-cuda/pad.o
-rw-rw-r-- 1 jart jart  65K May  8 19:37 ggml-cuda/arange.o
-rw-rw-r-- 1 jart jart  65K May  8 19:37 ggml-cuda/clamp.o
-rw-rw-r-- 1 jart jart  65K May  8 19:37 ggml-cuda/scale.o
-rw-rw-r-- 1 jart jart  65K May  8 19:37 ggml-cuda/quantize.o

The heaviest function is this one:

// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
__launch_bounds__(nwarps*WARP_SIZE, 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,

GPU support for flash attention can't be included in llamafile because we deal with a 4GB limit on Windows.

For comparison, in December ggml-cuda.so built with -arch=all was 12mb. By February it was 16mb. By April it was 50mb. Now it's 90gb. On my project we've already started using gzip to compress the ggml-cuda DSO. We've also reduced our support vector to -arch=all-major. Everything that can be done is being done on our end, since I'd like to be able to include everything if possible. However, this op seems like it could benefit from a refactoring.

JohannesGaessler (Collaborator) commented

> By February it was 16mb. By April it was 50mb. Now it's 90gb.

I assume this is simply a typo and you mean 90mb.

> We've also reduced our support vector to -arch=all-major.

When we (slaren, a user, and I) tested compiling for different CUDA architectures months ago, we found that there is no measurable performance difference between compiling for the minimum needed CUDA architecture and the actual CUDA arch of the GPU. So, assuming you use CUDA 12, it should be sufficient to compile for CUDA architectures 5.2, 6.0, 6.1, and 7.0 with the current code.
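
As a concrete sketch (the exact flags that the llama.cpp Makefile/CMake setup passes may differ, so treat this as an illustration rather than the project's actual build line), compiling one translation unit for just those compute capabilities, with PTX for the newest of them embedded so later GPUs can still JIT-compile it, would look roughly like:

# hedged example: replace -arch=all with explicit -gencode pairs
nvcc -c ggml-cuda/fattn.cu -o ggml-cuda/fattn.o \
    -gencode arch=compute_52,code=sm_52 \
    -gencode arch=compute_60,code=sm_60 \
    -gencode arch=compute_61,code=sm_61 \
    -gencode arch=compute_70,code=sm_70 \
    -gencode arch=compute_70,code=compute_70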

> Everything that can be done is being done on our end, since I'd like to be able to include everything if possible. However, this op seems like it could benefit from a refactoring.

The reasons why the FlashAttention kernel needs so much space are that

  1. it is simply a large kernel that does many things at once in order to avoid having to write the KQ matrix to VRAM, and
  2. it makes heavy use of templating to compile many different versions of the kernel so that the compiler can optimize the code for specific combinations of head sizes and batch sizes.

The first reason is, I think, fundamentally unavoidable. The second can only be avoided if you accept a significant performance penalty or reduce the number of cases covered by the kernel; intuitively, I would expect a kernel without templating to be at least 2x slower. What you could do on your end to reduce the file size without a performance penalty is to compile the kernel only for the head size of the model with which you package the code; all other head sizes are never going to be used anyway. In the same manner you could compile only those kernels for quantized data that match the quantization format of the packaged model, to reduce the size of mmq.o, mmvq.o, and dmmv.o.
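
As an illustration of that last point, here is a minimal, hypothetical sketch, not the actual llama.cpp dispatch code: flash_attn_sketch, launch_flash_attn, and the GGML_ONLY_HEAD_SIZE_128 build flag are invented for this example. It shows how every head size the host-side launcher can dispatch to forces its own template instantiation of the kernel (each of which is then compiled once per target architecture), and how restricting the build to the packaged model's head size drops the rest.

#include <cuda_runtime.h>

// Each head size D is a separate compile-time instantiation; the object size
// grows with the number of cases covered, multiplied by the number of targeted
// CUDA architectures.
template <int D>
__global__ void flash_attn_sketch(const float * Q, const float * K,
                                  const float * V, float * dst) {
    // The real kernel body (tiling, online softmax, accumulation) would be
    // specialized for D; it is omitted because only the instantiation pattern
    // matters for this sketch.
}

// Host-side dispatch: the switch picks an instantiation at runtime, but every
// case listed here must be present in the compiled binary.
static void launch_flash_attn(int head_size, const float * Q, const float * K,
                              const float * V, float * dst, dim3 grid, dim3 block) {
    switch (head_size) {
#ifdef GGML_ONLY_HEAD_SIZE_128  // hypothetical flag for a single packaged model
        case 128: flash_attn_sketch<128><<<grid, block>>>(Q, K, V, dst); break;
#else
        case  64: flash_attn_sketch< 64><<<grid, block>>>(Q, K, V, dst); break;
        case  80: flash_attn_sketch< 80><<<grid, block>>>(Q, K, V, dst); break;
        case  96: flash_attn_sketch< 96><<<grid, block>>>(Q, K, V, dst); break;
        case 112: flash_attn_sketch<112><<<grid, block>>>(Q, K, V, dst); break;
        case 128: flash_attn_sketch<128><<<grid, block>>>(Q, K, V, dst); break;
        case 256: flash_attn_sketch<256><<<grid, block>>>(Q, K, V, dst); break;
#endif
        default: break; // unsupported head size
    }
}

Building with GGML_ONLY_HEAD_SIZE_128 defined would make nvcc emit code for a single D per architecture, which is where most of the size reduction in this scheme would come from; the same idea applies per quantization format for the quantized matrix-multiplication kernels.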
