[Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies #4696
@@ -0,0 +1,63 @@ (new file: a cProfile-based benchmark for the block-hashing function)

import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams

# A very long prompt; the total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
               ] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)


def main(args):
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    print("------warm up------")
    for i in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for i in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # Analyze the runtime of the hashing function.
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    total_calls = 0
    for func in stats.stats:
        if 'hash_of_block' in func[2]:
            total_time = stats.stats[func][3]
            total_calls = stats.stats[func][0]
    percentage = (total_time / stats.total_tt) * 100
    print(f"Hashing took {total_time:.2f} seconds, "
          f"{percentage:.2f}% of the total runtime.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance of the hashing function in '
        'automatic prefix caching.')
    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceManagerV2')
    args = parser.parse_args()
    main(args)
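For reference, a hedged usage sketch; the script's actual path and filename are not shown in this diff view, so benchmark_hashing.py below is a placeholder:

    python benchmark_hashing.py --model lmsys/longchat-7b-16k --output-len 10 --use-v2-block-manager

Note that the script hardcodes enable_prefix_caching=True in the LLM constructor, so the --enable-prefix-caching flag is parsed but currently unused.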
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 from vllm.lora.request import LoRARequest
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -75,7 +75,7 @@ def __init__(
         self,
         request_id: str,
         prompt: str,
-        prompt_token_ids: List[int],
+        prompt_token_ids: Union[List[int], Tuple[int, ...]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
@@ -84,7 +84,7 @@ def __init__(
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
-        self.prompt_token_ids = prompt_token_ids
+        self.prompt_token_ids = list(prompt_token_ids)
         self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished

Review discussion on the list(...) cast:
- Comment: Do we need to do this? Maybe it would be better to change the tests instead.
- Reply: The test case corresponding to this change is:

      for req_output in req_outputs:
          ...
          prompt_ids = req_output.prompt_token_ids
          ...
          for sample in req_output.outputs:
              ...
              req_sample_output_ids.append(prompt_ids + output_ids)
              ...
          ...

  And I guess user code processing vLLM's output uses similar patterns.
- Comment: Hmm, I think it would be fine to keep it as a tuple and cast it to a list in the test code, tbh.
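For context, a minimal standalone sketch of the failure mode under discussion (illustrative values, not vLLM code): concatenating a tuple with a list raises TypeError, which is why the constructor casts back to a list here.

    prompt_ids = (1, 2, 3)  # prompt token ids, stored as a tuple after this PR
    output_ids = [4, 5]     # generated token ids, a plain list

    try:
        combined = prompt_ids + output_ids  # TypeError: can only concatenate tuple (not "list") to tuple
    except TypeError:
        combined = list(prompt_ids) + output_ids  # works: [1, 2, 3, 4, 5]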
@@ -2,7 +2,7 @@
 import copy
 import enum
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 from vllm.block import LogicalTokenBlock
 from vllm.lora.request import LoRARequest
@@ -112,13 +112,13 @@ class SequenceData:

     def __init__(
         self,
-        prompt_token_ids: List[int],
+        prompt_token_ids: Union[List[int], Tuple[int, ...]],
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
         if output_token_ids is None:
             output_token_ids = []

-        self.prompt_token_ids = prompt_token_ids
+        self.prompt_token_ids: Tuple[int, ...] = tuple(prompt_token_ids)
         self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).

Review discussion on the tuple conversion:
- Comment: Add a comment that this should not be changed?
- Comment: Also consider making it a private attribute?
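A minimal sketch of the private-attribute variant suggested above (hypothetical names, not the code in this PR): store the hashable tuple privately and expose it through a read-only property, so callers cannot rebind it by accident.

    from typing import Tuple

    class SequenceData:
        def __init__(self, prompt_token_ids) -> None:
            # Private, immutable storage; safe to hash and share.
            self._prompt_token_ids: Tuple[int, ...] = tuple(prompt_token_ids)

        @property
        def prompt_token_ids(self) -> Tuple[int, ...]:
            return self._prompt_token_ids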
@@ -139,7 +139,18 @@ def get_output_len(self) -> int:
         return len(self.output_token_ids)

     def get_token_ids(self) -> List[int]:
-        return self.prompt_token_ids + self.output_token_ids
+        return list(self.prompt_token_ids) + self.output_token_ids

Review discussion on the list(...) conversion:
- Comment: Remove the tuple input and remove the conversion!
- Reply: Agreed. For performance we need to avoid the tuple -> list conversion. How about storing the list version as ...
- Comment: Yeah, I think that sounds good to me. (It should be fine since the prompt tokens are not going to be changed.)
- Comment: Revert this part? (Ping me when it is done!)
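A sketch of the caching idea from the thread above, using a hypothetical _prompt_token_ids_list attribute (the attribute name in the truncated comment is not visible here). Since prompt tokens never change after construction, a list copy can be made once and reused, so get_token_ids() avoids a tuple -> list conversion on every call:

    from typing import List, Optional, Tuple

    class SequenceData:
        def __init__(self,
                     prompt_token_ids,
                     output_token_ids: Optional[List[int]] = None) -> None:
            self.prompt_token_ids: Tuple[int, ...] = tuple(prompt_token_ids)
            # One-time list copy; valid because prompt tokens are immutable.
            self._prompt_token_ids_list: List[int] = list(prompt_token_ids)
            self.output_token_ids: List[int] = output_token_ids or []

        def get_token_ids(self) -> List[int]:
            # No per-call conversion of the prompt tuple.
            return self._prompt_token_ids_list + self.output_token_ids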
+    def get_prefix_token_ids(
+            self, num_tokens: int
+    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
+        """Get prefix tokens, and make the return value hashable."""
+        prompt_length = len(self.prompt_token_ids)
+        if num_tokens > prompt_length:
+            return (self.prompt_token_ids,
+                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
+        else:
+            return (self.prompt_token_ids[:num_tokens], None)

Review discussion on the num_tokens > prompt_length branch:
- Comment: When does this happen?
- Reply: This happens when calculating hashes for both the user input (i.e. the prompt tokens) and the LLM-generated output (i.e. the output tokens).
- Comment: Can it happen under normal circumstances, or is it only for the recomputation case?
- Reply: Yes, it will happen for normal cases (inside function ...).
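To make the two branches concrete, a worked example with illustrative values:

    # Suppose prompt_token_ids == (1, 2, 3) and output_token_ids == [4, 5].
    #
    # Prefix entirely inside the prompt (num_tokens <= prompt length):
    #   get_prefix_token_ids(2)  -> ((1, 2), None)
    #
    # Prefix spilling into generated tokens (num_tokens > prompt length),
    # e.g. when hashing blocks that cover output tokens during decoding:
    #   get_prefix_token_ids(4)  -> ((1, 2, 3), (4,))
    #
    # Both return values are tuples of tuples, hence hashable.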
     def get_num_computed_tokens(self) -> int:
         """Return the number of prefill tokens that are already computed."""
@@ -174,7 +185,7 @@ def get_last_token_id(self) -> int:
             return self.prompt_token_ids[-1]
         return self.output_token_ids[-1]

-    def get_prompt_token_ids(self) -> List[int]:
+    def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.prompt_token_ids

     def get_output_token_ids(self) -> List[int]:
@@ -245,14 +256,9 @@ def get_output_text_to_return(self, buffer_length: int):
                 self.output_text)

     def hash_of_block(self, logical_idx: int) -> int:
         # TODO This can produce incorrect hash when block size > prompt size

         # Compute the number of tokens in the sequence
-        # TODO: The current hashing function is O(L^2). We should optimize
-        # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        return hash(
-            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
+        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
+        return hash((hashed_tokens, self.lora_int_id))

     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
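To make the speedup concrete, a standalone sketch contrasting the old and new paths (simplified, using hypothetical free functions rather than the Sequence methods):

    from typing import List, Tuple

    def hash_old(prompt: Tuple[int, ...], outputs: List[int],
                 num_tokens: int, lora_int_id: int = 0) -> int:
        # Old path: materializes a list of *all* tokens, then slices it and
        # converts to a tuple -- roughly two O(L) copies per hashed block.
        token_ids = list(prompt) + outputs
        return hash((tuple(token_ids[:num_tokens]), lora_int_id))

    def hash_new(prompt: Tuple[int, ...], outputs: List[int],
                 num_tokens: int, lora_int_id: int = 0) -> int:
        # New path: slices the already-hashable prompt tuple directly and
        # touches output tokens only when the block extends past the prompt.
        if num_tokens > len(prompt):
            prefix = (prompt, tuple(outputs[:num_tokens - len(prompt)]))
        else:
            prefix = (prompt[:num_tokens], None)
        return hash((prefix, lora_int_id))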
@@ -306,7 +312,7 @@ def get_output_len(self) -> int:
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()

-    def get_prompt_token_ids(self) -> List[int]:
+    def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.data.get_prompt_token_ids()

     def get_last_token_id(self) -> int:

@@ -433,7 +439,7 @@ def prompt(self) -> str:
         return next(iter(self.seqs_dict.values())).prompt

     @property
-    def prompt_token_ids(self) -> List[int]:
+    def prompt_token_ids(self) -> Tuple[int, ...]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).data.prompt_token_ids
Review comment on the prompt_token_ids property:
- Comment: Let's remove the tuple input?