[Feature]: Tree attention about Speculative Decoding #3960

Open
yukavio opened this issue Apr 10, 2024 · 4 comments

Comments

yukavio commented Apr 10, 2024

🚀 The feature, motivation and pitch

I want to implement the tree attention feature for vLLM mentioned in the roadmap, but I don't know whether I should implement it based on the paged-attention kernel implemented in vLLM or on FlashInfer, since I found that we plan to replace this kernel in this PR.

Alternatives

No response

Additional context

No response

cadedaniel (Collaborator) commented Apr 10, 2024

Thanks for your interest in contributing! FYI tree attention is a bit complicated to implement with non-contiguous KV cache, since intra-block attention masking has not been implemented anywhere AFAIK. We can get around this by limiting vLLM to block size of 1, but this makes it difficult to optimize latency of verification as we limit the allowed vLLM configuration space.

The way I'd recommend going about this is to implement intra-block attention masking first, then integrate it with vLLM. This is the surefire way to obtain the best latency reduction possible in vLLM. The steps are as follows (a rough sketch of the masking idea appears after this list):

  • (Kernel) Support attention mask inside single block (M)
  • (Worker) Support attention mask in Worker API (S-M)
  • (Spec Decode Framework) Propose, score, and verify top-k candidates (M) (e.g. implement replacement for this)
  • (Spec Decode Framework) Defragment accepted KV blocks (S-M)
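
To make the first step more concrete, here is a minimal, illustrative sketch in plain PyTorch (not vLLM's kernel API) of the kind of mask an intra-block attention kernel would need to express: within one KV-cache block, accepted context tokens attend causally, while each speculative candidate attends to the context plus its ancestors in the candidate tree. The function name and the parent-index encoding are assumptions made for illustration.

```python
# Minimal sketch, assuming one block holds both accepted context tokens and
# packed speculative candidates; `intra_block_tree_mask` and the parent-index
# encoding are illustrative, not part of vLLM.
import torch

def intra_block_tree_mask(num_context: int, parent: list[int]) -> torch.Tensor:
    """mask[q, k] is True where query token q may attend to key token k.

    num_context: accepted tokens in the block (plain causal attention).
    parent: for each speculative token, the index of its parent among the
            speculative tokens, or -1 if its parent is the last context token.
    """
    n_spec = len(parent)
    total = num_context + n_spec
    mask = torch.zeros(total, total, dtype=torch.bool)

    # Accepted context: standard causal mask.
    for q in range(num_context):
        mask[q, : q + 1] = True

    # Speculative candidates: full context plus ancestors on their tree path.
    for i in range(n_spec):
        q = num_context + i
        mask[q, :num_context] = True
        j = i
        while j != -1:
            mask[q, num_context + j] = True
            j = parent[j]
    return mask

# 4 accepted tokens, then 3 candidates: cand0 and cand1 both extend the
# context directly, cand2 extends cand0.
print(intra_block_tree_mask(4, parent=[-1, -1, 0]).int())
```

With a mask like this available inside a block, the verifier could score a whole candidate tree in one pass instead of being constrained to a block size of 1.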

After the remaining open sourcing work is complete, I'll add some documentation for this.

More background information here: https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8

reyoung commented Apr 10, 2024

@cadedaniel @yukavio

Tree attention mechanisms can also be utilized to generate multiple outcomes from the same prompt by varying the seeds.

This approach is an effective strategy to ensure the stability of results produced by Large Language Models (LLMs). For instance, when employing an LLM as a scoring tool to derive metrics, one could sample the LLM's outputs multiple times. By averaging these samples, a more reliable result can be obtained.


This feature might become available following the implementation of tree attention mechanisms.
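
As a tiny illustration of the averaging idea (the `score_once` helper below is a hypothetical stand-in for a seeded LLM scoring call, not a vLLM API):

```python
# Sketch only: sample the same scoring prompt with different seeds and average
# the results to reduce sampling variance. `score_once` is a placeholder.
import random
from statistics import mean

def score_once(prompt: str, seed: int) -> float:
    random.seed(seed)                 # placeholder for seeded LLM sampling
    return random.uniform(0.0, 1.0)   # placeholder for a parsed numeric score

def stable_score(prompt: str, num_samples: int = 5) -> float:
    return mean(score_once(prompt, seed) for seed in range(num_samples))

print(stable_score("Rate this answer from 0 to 1: ..."))
```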

yukavio (Author) commented Apr 10, 2024

@cadedaniel
Thanks for your reply. I have read your document, and it seems that the key problem is that each token in the score phase requires a loop over, and computation against, the entire KV cache.
I think this problem can be solved by storing all pre-score tokens for a given sequence at adjacent addresses, instead of treating them as different sequences after expansion. In this way, we can perform the calculation efficiently on tensor cores with a specific attention mask.
But in this way, we should organize the pre-score tokens as one sequence (left in the image) instead of multiple sequences (right in the image).
[image: pre-score tokens packed into a single sequence (left) vs. expanded into multiple sequences (right)]
If you think this way of organizing pre-score tokens is appropriate, I can implement the tensor-core CUDA kernel with a tree attention mask.
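
A rough sketch of what scoring the packed layout could look like (plain PyTorch 2.x; the shapes and function name are assumptions, not an existing vLLM kernel): with all pre-score tokens of one request stored contiguously, a single masked attention call can score the whole candidate tree on tensor cores, and the tree structure lives entirely in the mask.

```python
# Illustrative only: one request's tokens packed as a single sequence, scored
# with one masked attention call. Shapes and names are assumptions.
import torch
import torch.nn.functional as F

def score_packed_candidates(q, k, v, tree_mask):
    """q, k, v: [num_tokens, num_heads, head_dim] for one request, with the
    accepted tokens followed by the packed pre-score (speculative) tokens.
    tree_mask: [num_tokens, num_tokens] bool, True where attention is allowed."""
    # Reorder to [heads, tokens, dim] so the request becomes one batched matmul
    # that maps well onto tensor cores; the tree constraints sit in the mask.
    q, k, v = (x.transpose(0, 1) for x in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=tree_mask)
    return out.transpose(0, 1)  # back to [tokens, heads, dim]

# Toy shapes: 4 accepted tokens + 3 candidates, 8 heads, head_dim 64.
t, h, d = 7, 8, 64
q, k, v = (torch.randn(t, h, d) for _ in range(3))
mask = torch.tril(torch.ones(t, t, dtype=torch.bool))  # stand-in for a real tree mask
print(score_packed_candidates(q, k, v, mask).shape)    # torch.Size([7, 8, 64])
```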

cadedaniel (Collaborator) commented
@yukavio you should talk with @LiuXiaoxuanPKU, who is adding MQA scoring to vLLM.
